
Fix MTEBEvaluator: device mapping, padding-free inference, last-token pooling, L2 normalization#2415

Merged
jambayk merged 15 commits into main from natke/mteb-device-fix on Apr 16, 2026

Conversation

@natke
Contributor

@natke natke commented Apr 14, 2026

Fixes several issues in the MTEBEvaluator for embedding model evaluation:

Device mapping

Maps Olive's Device.GPU ("gpu") to PyTorch's "cuda" when initializing SentenceTransformer in the HF evaluation path. Also handles indexed devices (e.g. gpu:0 → cuda:0).
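
A minimal sketch of the mapping (the helper name is illustrative, not the actual Olive code):

```python
def olive_device_to_torch(device: str) -> str:
    """Map Olive device strings ("gpu", "gpu:1") to PyTorch ones ("cuda", "cuda:1")."""
    name, _, index = device.partition(":")
    if name.lower() != "gpu":
        return device  # "cpu" and other devices pass through unchanged
    return f"cuda:{index}" if index else "cuda"

# SentenceTransformer(model_name, device=olive_device_to_torch("gpu:0"))  # -> "cuda:0"
```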

Padding-free inference for GenAI

GenAI's Generator does not accept an attention_mask, so padded batches produce contaminated hidden states via self-attention to padding tokens. Fix: process each sentence individually with only its real tokens, eliminating padding entirely.
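
Roughly, the padding-free path looks like this (`run_genai_forward` is a hypothetical stand-in for the GenAI forward call, and an HF-style tokenizer is assumed; this is a sketch, not the literal diff):

```python
import numpy as np

def encode_without_padding(sentences, tokenizer, run_genai_forward):
    """Per-sentence inference so no padding tokens ever reach the model.

    run_genai_forward is a stand-in that returns hidden states of shape
    (1, seq_len, hidden_dim) for a single unpadded sequence.
    """
    embeddings = []
    for text in sentences:
        # Tokenize one sentence at a time with padding disabled: real tokens only.
        input_ids = tokenizer(text, return_tensors="np", padding=False)["input_ids"]
        hidden = run_genai_forward(input_ids)   # (1, seq_len, hidden_dim)
        embeddings.append(hidden[0, -1])        # last token is always a real token
    return np.stack(embeddings)
```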

Last-token pooling

Replaced mean pooling with last-token pooling in the GenAI and ORT wrappers to match models like Qwen3-Embedding that use pooling_mode_lasttoken=True.
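
For padded batches, last-token pooling picks the hidden state at the last position covered by the attention mask. A sketch assuming right padding, not the exact wrapper code:

```python
import numpy as np

def last_token_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Select the hidden state of the last real token in each sequence.

    hidden_states: (batch, seq_len, hidden_dim); attention_mask: (batch, seq_len).
    With right padding, the last real token sits at index sum(mask) - 1.
    """
    last_idx = attention_mask.sum(axis=1).astype(int) - 1           # (batch,)
    return hidden_states[np.arange(hidden_states.shape[0]), last_idx]
```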

L2 normalization

Added L2 normalization after pooling in the base encode() method, matching the 2_Normalize module in the SentenceTransformer pipeline.
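
The normalization itself is unit-length scaling of each pooled embedding (a sketch of the idea, not the literal diff):

```python
import numpy as np

def l2_normalize(embeddings: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm, mirroring SentenceTransformer's Normalize module."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.clip(norms, eps, None)
```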

Results

These fixes close the score gap between HF and GenAI evaluation:

  • Before: HF 0.785 vs GenAI 0.651 (STS17 main_score)
  • After: HF 0.785 vs GenAI 0.785

natke and others added 7 commits April 14, 2026 12:47
…nsformer

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Qwen3-Embedding uses last-token pooling (not mean pooling) and L2
normalization, matching its SentenceTransformer pipeline config:
- pooling_mode_lasttoken: true
- 2_Normalize module

This fixes the ~17% score drop between HF and exported model evaluation.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Last-token pooling made scores worse (0.378 vs 0.651 with mean pooling),
likely due to GenAI hidden_states not aligning with HF tokenizer
attention_mask positions. Reverting pooling to mean while keeping L2
normalization, which should still improve scores.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Temporary debug logging — remove before merge.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
GenAI hidden_states shape matches input_ids shape exactly (including
padding positions), so last-token pooling via attention_mask is correct.
Debug logging kept temporarily for verification.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
GenAI Generator doesn't accept attention_mask, so padded batches
produce contaminated hidden states. Fix: process each sentence
individually with only its real tokens, then take last-token pooling.

This should close the gap between HF (0.785) and GenAI (0.651) scores.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings April 14, 2026 21:49
Contributor

Copilot AI left a comment


Pull request overview

Note: Copilot was unable to run its full agentic suite in this review.

Improves correctness and consistency of MTEB embedding evaluation across HF / ORT / GenAI backends by aligning device strings, pooling strategy, padding behavior, and embedding normalization.

Changes:

  • Map Olive gpu / gpu:<idx> device strings to PyTorch cuda / cuda:<idx> for SentenceTransformer initialization.
  • Switch ORT + GenAI wrappers from mean pooling to last-token pooling; avoid padding in GenAI by encoding each sequence using only real tokens.
  • Add L2 normalization to ORT embeddings to match SentenceTransformer’s Normalize module.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| olive/evaluator/olive_evaluator.py | Normalizes Olive device strings to PyTorch-compatible cuda strings in the HF evaluation path. |
| olive/evaluator/mteb_ort.py | Adds L2 normalization, switches pooling to last-token, and removes padding from GenAI inference via per-sample processing. |

4 comment threads on olive/evaluator/mteb_ort.py (outdated)
natke and others added 4 commits April 14, 2026 15:41
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…-a-time

Group sequences with equal real token counts into a single Generator
call, reducing per-sample overhead while still avoiding padding
contamination.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
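
A rough illustration of the length-grouping idea from this commit (function name and data layout are assumptions, not the actual implementation):

```python
from collections import defaultdict

def group_by_token_count(tokenized):
    """Bucket sequences by their real (unpadded) token count.

    tokenized: list of 1-D input_ids arrays, one per sentence. Each bucket can
    be stacked into a rectangular batch, so the Generator sees no padding while
    per-call overhead drops compared to strictly one sentence at a time.
    """
    buckets = defaultdict(list)
    for i, ids in enumerate(tokenized):
        buckets[len(ids)].append((i, ids))
    return buckets
```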
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
3 comment threads on olive/evaluator/mteb_ort.py (fixed)
natke and others added 2 commits April 15, 2026 09:42
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
natke and others added 2 commits April 15, 2026 12:42
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@jambayk jambayk merged commit 21fcca9 into main Apr 16, 2026
11 checks passed
@jambayk jambayk deleted the natke/mteb-device-fix branch April 16, 2026 20:18
xiaoyu-work pushed a commit that referenced this pull request Apr 17, 2026

Labels: None yet

Projects: None yet

4 participants