diff --git a/src/pipelines.js b/src/pipelines.js index 2b064d522..3b15af53d 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -67,6 +67,7 @@ import { Tensor, mean_pooling, interpolate, + quantize_embeddings, } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; @@ -1112,6 +1113,8 @@ export class ZeroShotClassificationPipeline extends (/** @type {new (options: Te * @typedef {Object} FeatureExtractionPipelineOptions Parameters specific to feature extraction pipelines. * @property {'none'|'mean'|'cls'} [pooling="none"] The pooling method to use. * @property {boolean} [normalize=false] Whether or not to normalize the embeddings in the last dimension. + * @property {boolean} [quantize=false] Whether or not to quantize the embeddings. + * @property {'binary'|'ubinary'} [precision='binary'] The precision to use for quantization. * * @callback FeatureExtractionPipelineCallback Extract the features of the input(s). * @param {string|string[]} texts One or several texts (or one list of texts) to get the features of. @@ -1157,6 +1160,16 @@ export class ZeroShotClassificationPipeline extends (/** @type {new (options: Te * // dims: [1, 384] * // } * ``` + * **Example:** Calculating binary embeddings with `sentence-transformers` models. + * ```javascript + * const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2'); + * const output = await extractor('This is a simple test.', { pooling: 'mean', quantize: true, precision: 'binary' }); + * // Tensor { + * // type: 'int8', + * // data: Int8Array [49, 108, 24, ...], + * // dims: [1, 48] + * // } + * ``` */ export class FeatureExtractionPipeline extends (/** @type {new (options: TextPipelineConstructorArgs) => FeatureExtractionPipelineType} */ (Pipeline)) { /** @@ -1171,6 +1184,8 @@ export class FeatureExtractionPipeline extends (/** @type {new (options: TextPip async _call(texts, { pooling = /** @type {'none'} */('none'), normalize = false, + quantize = false, + precision = /** @type {'binary'} */('binary'), } = {}) { // Run tokenization @@ -1203,6 +1218,10 @@ export class FeatureExtractionPipeline extends (/** @type {new (options: TextPip result = result.normalize(2, -1); } + if (quantize) { + result = quantize_embeddings(result, precision); + } + return result; } } diff --git a/src/tokenizers.js b/src/tokenizers.js index 5b58e37c0..e671c8318 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -2653,8 +2653,8 @@ export class PreTrainedTokenizer extends Callable { } } else { - if (text === null) { - throw Error('text may not be null') + if (text === null || text === undefined) { + throw Error('text may not be null or undefined') } if (Array.isArray(text_pair)) { diff --git a/src/utils/tensor.js b/src/utils/tensor.js index ccdf781be..469054cac 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -1193,3 +1193,47 @@ export function ones(size) { export function ones_like(tensor) { return ones(tensor.dims); } + +/** + * Quantizes the embeddings tensor to binary or unsigned binary precision. + * @param {Tensor} tensor The tensor to quantize. + * @param {'binary'|'ubinary'} precision The precision to use for quantization. + * @returns {Tensor} The quantized tensor. + */ +export function quantize_embeddings(tensor, precision) { + if (tensor.dims.length !== 2) { + throw new Error("The tensor must have 2 dimensions"); + } + if (tensor.dims.at(-1) % 8 !== 0) { + throw new Error("The last dimension of the tensor must be a multiple of 8"); + } + if (!['binary', 'ubinary'].includes(precision)) { + throw new Error("The precision must be either 'binary' or 'ubinary'"); + } + + const signed = precision === 'binary'; + const dtype = signed ? 'int8' : 'uint8'; + + // Create a typed array to store the packed bits + const cls = signed ? Int8Array : Uint8Array; + const inputData = tensor.data; + const outputData = new cls(inputData.length / 8); + + // Iterate over each number in the array + for (let i = 0; i < inputData.length; ++i) { + // Determine if the number is greater than 0 + const bit = inputData[i] > 0 ? 1 : 0; + + // Calculate the index in the typed array and the position within the byte + const arrayIndex = Math.floor(i / 8); + const bitPosition = i % 8; + + // Pack the bit into the typed array + outputData[arrayIndex] |= bit << (7 - bitPosition); + if (signed && bitPosition === 0) { + outputData[arrayIndex] -= 128; + } + }; + + return new Tensor(dtype, outputData, [tensor.dims[0], tensor.dims[1] / 8]); +}