diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile
index f0867cfe46c3a..edbe4df020429 100644
--- a/examples/model-conversion/Makefile
+++ b/examples/model-conversion/Makefile
@@ -119,17 +119,30 @@ embedding-convert-model:
 embedding-run-original-model:
 	$(call validate_embedding_model_path,embedding-run-original-model)
 	@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
+	USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \
 	./scripts/embedding/run-original-model.py \
-	$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+	$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
+	$(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers)
+
+embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
+embedding-run-original-model-st: embedding-run-original-model
 
 embedding-run-converted-model:
 	@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
-	$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+	$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
+	$(if $(USE_POOLING),--pooling)
+
+embedding-run-converted-model-st: USE_POOLING=1
+embedding-run-converted-model-st: embedding-run-converted-model
 
 embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
 	@./scripts/embedding/compare-embeddings-logits.sh \
 	$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
 
+embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
+	@./scripts/embedding/compare-embeddings-logits.sh \
+	$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
+
 embedding-inspect-original-model:
 	$(call validate_embedding_model_path,embedding-inspect-original-model)
 	@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md
index e95e05cd377cc..05d95d588bae7 100644
--- a/examples/model-conversion/README.md
+++ b/examples/model-conversion/README.md
@@ -189,6 +189,23 @@
 This command will save two files to the `data` directory, one is a binary
 file containing logits which will be used for comparison with the converted
 model, and the other is a text file which allows for manual visual inspection.
 
+#### Using SentenceTransformer with numbered layers
+For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
+03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
+
+```console
+# Run original model with SentenceTransformer (applies all numbered layers)
+(venv) $ make embedding-run-original-model-st
+
+# Run converted model with pooling enabled
+(venv) $ make embedding-run-converted-model-st
+```
+
+This will use the SentenceTransformer library to load and run the model, which
+automatically applies all the numbered layers in the correct order. This is
+particularly useful when comparing with models that should include these
+additional transformation layers beyond just the base model output.
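As a reference for reviewers, the behaviour the new `-st` targets rely on can be pictured in a few lines of Python. This is an illustrative sketch only, not part of the patch; the model path is hypothetical:

```python
# Illustrative sketch, not part of the patch. Shows how SentenceTransformer
# applies the numbered layer directories (Pooling, Dense, Normalize) in
# order after the base transformer module.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("/path/to/embedding-model")  # hypothetical path

# SentenceTransformer subclasses torch.nn.Sequential, so the loaded modules
# can be listed in application order:
for name, module in model.named_children():
    print(name, type(module).__name__)  # e.g. 0 Transformer, 1 Pooling, ...

# encode() runs the full pipeline and returns one pooled vector per sentence.
emb = model.encode(["Hello world today"], convert_to_numpy=True)
print(emb.shape)  # (1, output_dim)
```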
+
 ### Model conversion
 After updates have been made to [gguf-py](../../gguf-py) to add support for the new model the model can be converted to GGUF format using the following command:
@@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits:
 (venv) $ make embedding-verify-logits
 ```
 
+For models with SentenceTransformer layers, use the `-st` verification target:
+```console
+(venv) $ make embedding-verify-logits-st
+```
+This convenience target automatically runs both the original model with SentenceTransformer
+and the converted model with pooling enabled, then compares the results.
+
 ### llama-server verification
 To verify that the converted model works with llama-server, the following command can be used:
diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp
index 6dc334189f4be..bbd095e6034cc 100644
--- a/examples/model-conversion/logits.cpp
+++ b/examples/model-conversion/logits.cpp
@@ -1,4 +1,7 @@
 #include "llama.h"
+#include "common.h"
+
+
 #include <cstdio>
 #include <cstring>
 #include <string>
@@ -8,7 +11,10 @@
 static void print_usage(int, char ** argv) {
     printf("\nexample usage:\n");
-    printf("\n    %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]);
+    printf("\n    %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
+    printf("\n");
+    printf("    -embd-norm: normalization type for pooled embeddings (default: 2)\n");
+    printf("                -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
     printf("\n");
 }
 
@@ -17,6 +23,8 @@ int main(int argc, char ** argv) {
     std::string prompt = "Hello, my name is";
     int ngl = 0;
     bool embedding_mode = false;
+    bool pooling_enabled = false;
+    int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
 
     {
         int i = 1;
@@ -41,9 +49,13 @@
                     return 1;
                 }
             } else if (strcmp(argv[i], "-embd-mode") == 0) {
+                embedding_mode = true;
+            } else if (strcmp(argv[i], "-pooling") == 0) {
+                pooling_enabled = true;
+            } else if (strcmp(argv[i], "-embd-norm") == 0) {
                 if (i + 1 < argc) {
                     try {
-                        embedding_mode = true;
+                        embd_norm = std::stoi(argv[++i]);
                     } catch (...) {
                         print_usage(argc, argv);
                         return 1;
                     }
@@ -112,7 +124,7 @@
     ctx_params.no_perf = false;
     if (embedding_mode) {
         ctx_params.embeddings = true;
-        ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
         ctx_params.n_ubatch = ctx_params.n_batch;
     }
 
@@ -143,17 +155,27 @@
         return 1;
     }
 
-    float * logits;
-    int n_logits;
+    float * data_ptr;
+    int data_size;
     const char * type;
+    std::vector<float> embd_out;
 
     if (embedding_mode) {
-        logits = llama_get_embeddings(ctx);
-        n_logits = llama_model_n_embd(model) * batch.n_tokens;
+        const int n_embd = llama_model_n_embd(model);
+        const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
+        const int n_embeddings = n_embd * n_embd_count;
+        float * embeddings;
         type = "-embeddings";
 
-        const int n_embd = llama_model_n_embd(model);
-        const int n_embd_count = batch.n_tokens;
+        if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
+            embeddings = llama_get_embeddings_seq(ctx, 0);
+            embd_out.resize(n_embeddings);
+            printf("Normalizing embeddings using norm: %d\n", embd_norm);
+            common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
+            embeddings = embd_out.data();
+        } else {
+            embeddings = llama_get_embeddings(ctx);
+        }
 
         printf("Embedding dimension: %d\n", n_embd);
         printf("\n");
@@ -164,7 +186,7 @@
 
             // Print first 3 values
             for (int i = 0; i < 3 && i < n_embd; i++) {
-                printf("%9.6f ", logits[j * n_embd + i]);
+                printf("%9.6f ", embeddings[j * n_embd + i]);
             }
 
             printf(" ... ");
@@ -172,7 +194,7 @@
 
             // Print last 3 values
             for (int i = n_embd - 3; i < n_embd; i++) {
                 if (i >= 0) {
-                    printf("%9.6f ", logits[j * n_embd + i]);
+                    printf("%9.6f ", embeddings[j * n_embd + i]);
                 }
             }
@@ -180,27 +202,33 @@
         }
         printf("\n");
-        printf("Embeddings size: %d\n", n_logits);
+        printf("Embeddings size: %d\n", n_embeddings);
+
+        data_ptr  = embeddings;
+        data_size = n_embeddings;
     } else {
-        logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-        n_logits = llama_vocab_n_tokens(vocab);
+        float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        const int n_logits = llama_vocab_n_tokens(vocab);
         type = "";
         printf("Vocab size: %d\n", n_logits);
+
+        data_ptr  = logits;
+        data_size = n_logits;
     }
 
     std::filesystem::create_directory("data");
 
-    // Save logits to binary file
+    // Save data to binary file
     char bin_filename[512];
     snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
-    printf("Saving logits to %s\n", bin_filename);
+    printf("Saving data to %s\n", bin_filename);
 
     FILE * f = fopen(bin_filename, "wb");
     if (f == NULL) {
         fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
         return 1;
     }
-    fwrite(logits, sizeof(float), n_logits, f);
+    fwrite(data_ptr, sizeof(float), data_size, f);
     fclose(f);
 
     // Also save as text for debugging
@@ -211,27 +239,27 @@
         fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
         return 1;
     }
-    for (int i = 0; i < n_logits; i++) {
-        fprintf(f, "%d: %.6f\n", i, logits[i]);
+    for (int i = 0; i < data_size; i++) {
+        fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
     }
     fclose(f);
 
     if (!embedding_mode) {
         printf("First 10 logits: ");
-        for (int i = 0; i < 10 && i < n_logits; i++) {
-            printf("%.6f ", logits[i]);
+        for (int i = 0; i < 10 && i < data_size; i++) {
+            printf("%.6f ", data_ptr[i]);
         }
         printf("\n");
 
         printf("Last 10 logits: ");
-        for (int i = n_logits - 10; i < n_logits; i++) {
-            if (i >= 0) printf("%.6f ", logits[i]);
+        for (int i = data_size - 10; i < data_size; i++) {
+            if (i >= 0) printf("%.6f ", data_ptr[i]);
         }
         printf("\n\n");
     }
 
-    printf("Logits saved to %s\n", bin_filename);
-    printf("Logits saved to %s\n", txt_filename);
+    printf("Data saved to %s\n", bin_filename);
+    printf("Data saved to %s\n", txt_filename);
 
     llama_free(ctx);
     llama_model_free(model);
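The `-embd-norm` values accepted above follow the conventions of `common_embd_normalize`. As a rough NumPy sketch of those semantics (an assumption based on the flag documentation in this patch, not the actual C++ implementation; the `0` max-absolute-int16 case is omitted):

```python
import numpy as np

# Sketch of the -embd-norm semantics (assumed from the documented values,
# not taken from the llama.cpp source; p=0 is omitted).
def embd_normalize(v: np.ndarray, p: int = 2) -> np.ndarray:
    if p == -1:    # no normalization
        return v
    if p == 2:     # Euclidean/L2 (the default)
        norm = np.sqrt(np.sum(v * v))
    else:          # 1 = taxicab, >2 = general p-norm
        norm = np.sum(np.abs(v) ** p) ** (1.0 / p)
    return v / norm if norm > 0 else v

unit = embd_normalize(np.array([3.0, 4.0]))  # -> [0.6, 0.8], L2 norm 1.0
```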
diff --git a/examples/model-conversion/requirements.txt b/examples/model-conversion/requirements.txt
index ac9f69e10bcc9..229b2ec75b75b 100644
--- a/examples/model-conversion/requirements.txt
+++ b/examples/model-conversion/requirements.txt
@@ -4,3 +4,4 @@ torchvision
 transformers
 huggingface-hub
 accelerate
+sentence-transformers
diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh
index f3e2676632070..0f490e6c3b20a 100755
--- a/examples/model-conversion/scripts/embedding/run-converted-model.sh
+++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh
@@ -5,6 +5,7 @@ set -e
 # Parse command line arguments
 CONVERTED_MODEL=""
 PROMPTS_FILE=""
+USE_POOLING=""
 
 while [[ $# -gt 0 ]]; do
     case $1 in
@@ -12,6 +13,10 @@ while [[ $# -gt 0 ]]; do
             PROMPTS_FILE="$2"
             shift 2
             ;;
+        --pooling)
+            USE_POOLING="1"
+            shift
+            ;;
         *)
             if [ -z "$CONVERTED_MODEL" ]; then
                 CONVERTED_MODEL="$1"
@@ -47,4 +52,8 @@ echo $CONVERTED_MODEL
 cmake --build ../../build --target llama-logits -j8
 
 # TODO: update logits.cpp to accept a --file/-f option for the prompt
-../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+if [ -n "$USE_POOLING" ]; then
+    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
+else
+    ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
+fi
diff --git a/examples/model-conversion/scripts/embedding/run-original-model.py b/examples/model-conversion/scripts/embedding/run-original-model.py
index 4a3e162413fa6..640e200a97dc3 100755
--- a/examples/model-conversion/scripts/embedding/run-original-model.py
+++ b/examples/model-conversion/scripts/embedding/run-original-model.py
@@ -14,6 +14,8 @@
 parser = argparse.ArgumentParser(description='Process model with specified path')
 parser.add_argument('--model-path', '-m', help='Path to the model')
 parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)')
+parser.add_argument('--use-sentence-transformers', action='store_true',
+                    help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)')
 args = parser.parse_args()
 
 def read_prompt_from_file(file_path):
@@ -31,41 +33,52 @@ def read_prompt_from_file(file_path):
 if model_path is None:
     parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable")
 
-tokenizer = AutoTokenizer.from_pretrained(model_path)
+# Determine if we should use SentenceTransformer
+use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
 
-config = AutoConfig.from_pretrained(model_path)
-
-# This can be used to override the sliding window size for manual testing. This
-# can be useful to verify the sliding window attention mask in the original model
-# and compare it with the converted .gguf model.
-if hasattr(config, 'sliding_window'):
-    original_sliding_window = config.sliding_window
-    #original_sliding_window = 6
-    print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
-
-print(f"Using unreleased model: {unreleased_model_name}")
-if unreleased_model_name:
-    model_name_lower = unreleased_model_name.lower()
-    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
-    class_name = f"{unreleased_model_name}Model"
-    print(f"Importing unreleased model module: {unreleased_module_path}")
-
-    try:
-        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
-        model = model_class.from_pretrained(model_path, config=config)
-    except (ImportError, AttributeError) as e:
-        print(f"Failed to import or load model: {e}")
-        exit(1)
+if use_sentence_transformers:
+    from sentence_transformers import SentenceTransformer
+    print("Using SentenceTransformer to apply all numbered layers")
+    model = SentenceTransformer(model_path)
+    tokenizer = model.tokenizer
+    config = model[0].auto_model.config  # type: ignore
 else:
-    model = AutoModel.from_pretrained(model_path, config=config)
-print(f"Model class: {type(model)}")
-print(f"Model file: {type(model).__module__}")
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+    config = AutoConfig.from_pretrained(model_path)
+
+    # This can be used to override the sliding window size for manual testing. This
+    # can be useful to verify the sliding window attention mask in the original model
+    # and compare it with the converted .gguf model.
+    if hasattr(config, 'sliding_window'):
+        original_sliding_window = config.sliding_window
+        #original_sliding_window = 6
+        print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
+
+    print(f"Using unreleased model: {unreleased_model_name}")
+    if unreleased_model_name:
+        model_name_lower = unreleased_model_name.lower()
+        unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+        class_name = f"{unreleased_model_name}Model"
+        print(f"Importing unreleased model module: {unreleased_module_path}")
+
+        try:
+            model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
+            model = model_class.from_pretrained(model_path, config=config)
+        except (ImportError, AttributeError) as e:
+            print(f"Failed to import or load model: {e}")
+            exit(1)
+    else:
+        model = AutoModel.from_pretrained(model_path, config=config)
+    print(f"Model class: {type(model)}")
+    print(f"Model file: {type(model).__module__}")
 
 # Verify the model is using the correct sliding window
-if hasattr(model.config, 'sliding_window'):
-    print(f"Model's sliding_window: {model.config.sliding_window}")
-else:
-    print("Model config does not have sliding_window attribute")
+if not use_sentence_transformers:
+    if hasattr(model.config, 'sliding_window'):  # type: ignore
+        print(f"Model's sliding_window: {model.config.sliding_window}")  # type: ignore
+    else:
+        print("Model config does not have sliding_window attribute")
 
 model_name = os.path.basename(model_path)
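Since the converted model is run with `LLAMA_POOLING_TYPE_MEAN`, the quantity being compared on both sides is a mean-pooled sentence embedding. A minimal sketch of mean pooling over a Hugging Face style `last_hidden_state`, assuming standard tensor shapes (illustrative, not code from this patch):

```python
import torch

# Illustrative sketch: mean pooling over valid tokens, roughly what
# LLAMA_POOLING_TYPE_MEAN and a mean-tokens SentenceTransformer Pooling
# module compute (any Dense/Normalize layers are applied separately).
# Assumes [batch, seq, hidden] hidden states and a [batch, seq] mask.
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [B, T, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                  # [B, H]
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # [B, 1]
    return summed / counts                                          # [B, H]
```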
@@ -75,34 +88,56 @@ def read_prompt_from_file(file_path):
 else:
     texts = ["Hello world today"]
 
-encoded = tokenizer(
-    texts,
-    padding=True,
-    truncation=True,
-    return_tensors="pt"
-)
-
-tokens = encoded['input_ids'][0]
-token_strings = tokenizer.convert_ids_to_tokens(tokens)
-for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
-    print(f"{token_id:6d} -> '{token_str}'")
-
 with torch.no_grad():
-    outputs = model(**encoded)
-    hidden_states = outputs.last_hidden_state  # Shape: [batch_size, seq_len, hidden_size]
-
-    # Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior)
-    all_embeddings = hidden_states[0].cpu().numpy()  # Shape: [seq_len, hidden_size]
-
-    print(f"Hidden states shape: {hidden_states.shape}")
-    print(f"All embeddings shape: {all_embeddings.shape}")
-    print(f"Embedding dimension: {all_embeddings.shape[1]}")
-
-    # Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE
-    n_embd = all_embeddings.shape[1]
-    n_embd_count = all_embeddings.shape[0]
-
-    print()  # Empty line to match C++ output
+    if use_sentence_transformers:
+        embeddings = model.encode(texts, convert_to_numpy=True)
+        all_embeddings = embeddings  # Shape: [batch_size, hidden_size]
+
+        encoded = tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        )
+        tokens = encoded['input_ids'][0]
+        token_strings = tokenizer.convert_ids_to_tokens(tokens)
+        for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+            print(f"{token_id:6d} -> '{token_str}'")
+
+        print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
+        print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")  # type: ignore
+    else:
+        # Standard approach: use base model output only
+        encoded = tokenizer(
+            texts,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        )
+
+        tokens = encoded['input_ids'][0]
+        token_strings = tokenizer.convert_ids_to_tokens(tokens)
+        for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+            print(f"{token_id:6d} -> '{token_str}'")
+
+        outputs = model(**encoded)
+        hidden_states = outputs.last_hidden_state  # Shape: [batch_size, seq_len, hidden_size]
+
+        all_embeddings = hidden_states[0].cpu().numpy()  # Shape: [seq_len, hidden_size]
+
+        print(f"Hidden states shape: {hidden_states.shape}")
+        print(f"All embeddings shape: {all_embeddings.shape}")
+        print(f"Embedding dimension: {all_embeddings.shape[1]}")
+
+    if len(all_embeddings.shape) == 1:
+        n_embd = all_embeddings.shape[0]  # type: ignore
+        n_embd_count = 1
+        all_embeddings = all_embeddings.reshape(1, -1)
+    else:
+        n_embd = all_embeddings.shape[1]  # type: ignore
+        n_embd_count = all_embeddings.shape[0]  # type: ignore
+
+    print()
 
     for j in range(n_embd_count):
         embedding = all_embeddings[j]
@@ -120,29 +155,23 @@ def read_prompt_from_file(file_path):
             print()  # New line
 
-    print()  # Final empty line to match C++ output
+    print()
 
     data_dir = Path("data")
     data_dir.mkdir(exist_ok=True)
 
     bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
     txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
 
-    # Save all embeddings flattened (matching what embedding.cpp would save if it did)
     flattened_embeddings = all_embeddings.flatten()
    flattened_embeddings.astype(np.float32).tofile(bin_filename)
 
     with open(txt_filename, "w") as f:
-        f.write(f"# Model class: {model_name}\n")
-        f.write(f"# Tokens: {token_strings}\n")
-        f.write(f"# Shape: {all_embeddings.shape}\n")
-        f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n")
-
+        idx = 0
         for j in range(n_embd_count):
-            f.write(f"# Token {j} ({token_strings[j]}):\n")
-            for i, value in enumerate(all_embeddings[j]):
-                f.write(f"{j}_{i}: {value:.6f}\n")
-            f.write("\n")
-    print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)")
+            for value in all_embeddings[j]:
+                f.write(f"{idx}: {value:.6f}\n")
+                idx += 1
+    print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
 
     print("")
     print(f"Saved bin embeddings to: {bin_filename}")
     print(f"Saved txt embeddings to: {txt_filename}")
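Both scripts write flat float32 dumps, which is what makes the size-based detection in `semantic_check.py` below possible. A sketch of reloading such a dump (the filename and `n_embd` value are illustrative):

```python
import numpy as np

# Sketch: reloading one of the flat float32 dumps written above. The
# filename and n_embd value are illustrative.
n_embd = 768
emb = np.fromfile("data/pytorch-mymodel-embeddings.bin", dtype=np.float32)

# A pooled run stores a single n_embd vector; a per-token run stores
# n_tokens * n_embd values, so the element count tells the two apart.
if emb.size == n_embd:
    emb = emb.reshape(1, n_embd)
else:
    emb = emb.reshape(-1, n_embd)
print(emb.shape)
```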
diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py
index 7fd417bceaa8b..2ac8b6b7b42cb 100644
--- a/examples/model-conversion/scripts/utils/semantic_check.py
+++ b/examples/model-conversion/scripts/utils/semantic_check.py
@@ -35,7 +35,11 @@ def cosine_similarity(a, b=None):
 
 def load_embeddings_from_file(filename, n_tokens, n_embd):
     embeddings = np.fromfile(filename, dtype=np.float32)
-    return embeddings.reshape(n_tokens, n_embd)
+    # Check if this is pooled (single embedding) or per-token embeddings
+    if len(embeddings) == n_embd:
+        return embeddings.reshape(1, n_embd)
+    else:
+        return embeddings.reshape(n_tokens, n_embd)
 
 def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
     np.set_printoptions(suppress=True, precision=6)
@@ -48,58 +52,83 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
     print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
 
     n_tokens = len(tokens)
+    is_pooled = python_emb.shape[0] == 1
+
+    if is_pooled:
+        print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
 
-    # 1. Direct embedding comparison
-    print(f"\n1. Raw Embedding Magnitude Comparison:")
-    # Check if the distance of each token embedding from the origin and compare
-    # if the vectors are on the same "sphere". This does not tell us about
-    # direction (meaning of the token embedding), just magnitude.
-    for i in range(n_tokens):
-        py_mag = np.linalg.norm(python_emb[i])  # calculate standard euclidean norm for Python embeddings
-        cpp_mag = np.linalg.norm(cpp_emb[i])    # calculate standard euclidean norm for llama.cpp embeddings
+        # 1. Direct embedding comparison for pooled embeddings
+        print(f"\n1. Raw Embedding Magnitude Comparison:")
+        py_mag = np.linalg.norm(python_emb[0])
+        cpp_mag = np.linalg.norm(cpp_emb[0])
         ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
-        print(f"   Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
-
-    # 2. Cosine similarity between tokens within each model
-    # Here we check the direction of token embeddings to see if the have the
-    # same meaning (similarity). This is done by calculating cosine similarity
-    # of a pair of token embeddings within each model.
-    print(f"\n2. Within-Model Token Similarities:")
-    print("   Python model:")
-    for i in range(n_tokens):
-        for j in range(i+1, n_tokens):
-            sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
-            print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
-
-    print("   llama.cpp model:")
-    for i in range(n_tokens):
-        for j in range(i+1, n_tokens):
-            sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
-            print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
-
-    # 3. Cross-model similarity (same token position)
-    print(f"\n3. Cross-Model Same-Token Similarities:")
-    for i in range(n_tokens):
-        sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
-        print(f"   Token {i} ({tokens[i]}): {sim:.4f}")
-
-    # 4. Similarity matrix comparison
-    print(f"\n4. Similarity Matrix Differences:")
-    py_sim_matrix = cosine_similarity(python_emb)
-    cpp_sim_matrix = cosine_similarity(cpp_emb)
-    diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
-
-    print(f"   Max difference: {np.max(diff_matrix):.4f}")
-    print(f"   Mean difference: {np.mean(diff_matrix):.4f}")
-    print(f"   RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
-
-    return {
-        'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
-        'similarity_matrix_diff': diff_matrix,
-        'max_diff': np.max(diff_matrix),
-        'mean_diff': np.mean(diff_matrix),
-        'rms_diff': np.sqrt(np.mean(diff_matrix**2))
-    }
+        print(f"   Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
+
+        # 2. Cross-model similarity for pooled embeddings
+        print(f"\n2. Cross-Model Pooled Embedding Similarity:")
+        sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
+        print(f"   Cosine similarity: {sim:.6f}")
+
+        return {
+            'cross_model_similarities': [sim],
+            'similarity_matrix_diff': np.array([[0.0]]),
+            'max_diff': 0.0,
+            'mean_diff': 0.0,
+            'rms_diff': 0.0
+        }
+    else:
+        # Original per-token comparison logic
+        # 1. Direct embedding comparison
+        print(f"\n1. Raw Embedding Magnitude Comparison:")
+        # Check if the distance of each token embedding from the origin and compare
+        # if the vectors are on the same "sphere". This does not tell us about
+        # direction (meaning of the token embedding), just magnitude.
+        for i in range(n_tokens):
+            py_mag = np.linalg.norm(python_emb[i])  # calculate standard euclidean norm for Python embeddings
+            cpp_mag = np.linalg.norm(cpp_emb[i])    # calculate standard euclidean norm for llama.cpp embeddings
+            ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
+            print(f"   Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
+
+        # 2. Cosine similarity between tokens within each model
+        # Here we check the direction of token embeddings to see if they have the
+        # same meaning (similarity). This is done by calculating cosine similarity
+        # of a pair of token embeddings within each model.
+        print(f"\n2. Within-Model Token Similarities:")
+        print("   Python model:")
+        for i in range(n_tokens):
+            for j in range(i+1, n_tokens):
+                sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
+                print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+        print("   llama.cpp model:")
+        for i in range(n_tokens):
+            for j in range(i+1, n_tokens):
+                sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
+                print(f"     {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+        # 3. Cross-model similarity (same token position)
+        print(f"\n3. Cross-Model Same-Token Similarities:")
+        for i in range(n_tokens):
+            sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
+            print(f"   Token {i} ({tokens[i]}): {sim:.4f}")
+
+        # 4. Similarity matrix comparison
+        print(f"\n4. Similarity Matrix Differences:")
+        py_sim_matrix = cosine_similarity(python_emb)
+        cpp_sim_matrix = cosine_similarity(cpp_emb)
+        diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
+
+        print(f"   Max difference: {np.max(diff_matrix):.4f}")
+        print(f"   Mean difference: {np.mean(diff_matrix):.4f}")
+        print(f"   RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
+
+        return {
+            'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
+            'similarity_matrix_diff': diff_matrix,
+            'max_diff': np.max(diff_matrix),
+            'mean_diff': np.mean(diff_matrix),
+            'rms_diff': np.sqrt(np.mean(diff_matrix**2))
+        }
 
 def read_prompt_from_file(file_path):
     try:
diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt
index 56b6752ac0645..6c6bea9490b4b 100644
--- a/requirements/requirements-all.txt
+++ b/requirements/requirements-all.txt
@@ -14,3 +14,5 @@
 -r ./requirements-tool_bench.txt
 -r ./requirements-gguf_editor_gui.txt
+
+-r ../examples/model-conversion/requirements.txt
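In the pooled branch the whole comparison reduces to a single cosine similarity between two vectors. For reference, a self-contained sketch of that check (illustrative; it mirrors the intent of the script's `cosine_similarity` helper rather than reusing it):

```python
import numpy as np

# Illustrative sketch of the single check the pooled branch reduces to:
# one cosine similarity between the Python and llama.cpp sentence embeddings.
def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# A value close to 1.0 (e.g. > 0.99) indicates the conversion preserved
# the embedding direction.
```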