EmbeddingGemma usage #1418

@MithrilMan

Question

I'm new to transformers.js.
I want to use EmbeddingGemma in my web app, and I've looked at the usage example at this link:
https://huggingface.co/blog/embeddinggemma#transformersjs

I've also seen a different approach to computing embeddings, using pipeline:
https://huggingface.co/docs/transformers.js/api/pipelines#pipelinesfeatureextractionpipeline

I'm trying to create a custom pipeline in TypeScript, and I'm building it like this:

class EmbeddingPipeline {
    private static instance: Promise<FeatureExtractionPipeline> | null = null;
    private static model = 'onnx-community/embeddinggemma-300m-ONNX';
    private static readonly task = 'feature-extraction';

    // Detected device (defaults to wasm)
    private static device: 'webgpu' | 'wasm' = 'wasm';
    private static deviceInitPromise: Promise<void> | null = null;

    private static async detectDeviceOnce(): Promise<void> {
        if (this.deviceInitPromise) return this.deviceInitPromise;
        this.deviceInitPromise = (async () => {
            if (typeof navigator !== 'undefined' && 'gpu' in navigator) {
                try {
                    const adapter = await (navigator as any).gpu.requestAdapter();
                    if (adapter) {
                        this.device = 'webgpu';
                        return;
                    }
                } catch {
                    // ignore, fallback to wasm
                }
            }
            this.device = 'wasm';
        })();
        return this.deviceInitPromise;
    }

    static getSelectedDevice(): 'webgpu' | 'wasm' {
        return this.device;
    }

    static async getInstance(progress_callback?: ProgressCallback): Promise<FeatureExtractionPipeline> {
        if (this.instance) return this.instance;

        // Detect the device only once
        await this.detectDeviceOnce();

        const build = async (device: 'webgpu' | 'wasm') =>
            pipeline(
                this.task,
                this.model,
                {
                    progress_callback,
                    dtype: 'q8',
                    device
                }
            ) as Promise<FeatureExtractionPipeline>;

        this.instance = (async (): Promise<FeatureExtractionPipeline> => {
            try {
                return await build(this.device);
            } catch (e) {
                if (this.device === 'webgpu') {
                    // Automatic fallback to wasm
                    this.device = 'wasm';
                    return await build('wasm');
                }
                throw e;
            }
        })();

        return this.instance;
    }
}


const getEmbeddingDevice = () => EmbeddingPipeline.getSelectedDevice();
const embedding_prefixes_per_task: Record<EmbeddingTask, string> = {
    'query': "task: search result | query: ",
    'document': "title: none | text: ",
};

export type EmbeddingTask = 'query' | 'document';

export const getEmbedding = async (task: EmbeddingTask, text: string): Promise<Float32Array> => {
    const extractor = await EmbeddingPipeline.getInstance();

    const prefix = embedding_prefixes_per_task[task];
    const result = await extractor(`${prefix}${text}`, { pooling: 'mean', normalize: true });

    return result.data as Float32Array;
};

I'm using the same sentences (with prefixes) as your example, and I'm running both my class and your code side by side to check whether they match, but the resulting embeddings are different.
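
One way to quantify how different the two results are is to compare the vectors directly rather than by eye. The following is only a minimal sketch, assuming both embeddings are available as Float32Array values of equal length; cosineSimilarity and maxAbsDiff are hypothetical helper names introduced here for illustration and are not part of transformers.js.

```typescript
// Hypothetical comparison helpers (not part of transformers.js).

// Cosine similarity between two vectors of equal length.
// A value close to 1 means the two vectors point in nearly the same direction.
function cosineSimilarity(a: Float32Array, b: Float32Array): number {
    if (a.length !== b.length) {
        throw new Error('Embeddings must have the same length');
    }
    let dot = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Largest element-wise absolute difference between two vectors.
function maxAbsDiff(a: Float32Array, b: Float32Array): number {
    if (a.length !== b.length) {
        throw new Error('Embeddings must have the same length');
    }
    let max = 0;
    for (let i = 0; i < a.length; i++) {
        max = Math.max(max, Math.abs(a[i] - b[i]));
    }
    return max;
}
```

Passing the Float32Array returned by getEmbedding and the vector produced by the blog example to these helpers would show whether the mismatch is a small numerical drift or a completely different vector.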

What am I doing wrong? Can you point me to documentation that explains how this works?

Thanks

Labels: question (Further information is requested)
