diff --git a/scripts/convert.py b/scripts/convert.py
index 87ac56f99..5b8620471 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -99,6 +99,10 @@
         'per_channel': False,
         'reduce_range': False,
     },
+    'wavlm': {
+        'per_channel': False,
+        'reduce_range': False,
+    },
 }
 
 MODELS_WITHOUT_TOKENIZERS = [
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 99f3421ab..7d7a5c169 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -1019,6 +1019,12 @@
             'microsoft/wavlm-base-plus',
             'microsoft/wavlm-large',
         ],
+
+        # Audio XVector (e.g., for speaker verification)
+        'audio-xvector': [
+            'microsoft/wavlm-base-plus-sv',
+            'microsoft/wavlm-base-sv',
+        ],
     },
     'whisper': {
         # Automatic speech recognition
diff --git a/src/models.js b/src/models.js
index 46c112a0b..cf7807300 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4735,6 +4735,49 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
     }
 }
 
+/**
+ * WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ *
+ * **Example:** Extract speaker embeddings with `WavLMForXVector`.
+ * ```javascript
+ * import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
+ *
+ * // Read and preprocess audio
+ * const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
+ * const audio = await read_audio(url, 16000);
+ * const inputs = await processor(audio);
+ *
+ * // Run model with inputs
+ * const model = await AutoModel.from_pretrained('Xenova/wavlm-base-plus-sv');
+ * const outputs = await model(inputs);
+ * // {
+ * //   logits: Tensor {
+ * //     dims: [ 1, 512 ],
+ * //     type: 'float32',
+ * //     data: Float32Array(512) [0.5847219228744507, ...],
+ * //     size: 512
+ * //   },
+ * //   embeddings: Tensor {
+ * //     dims: [ 1, 512 ],
+ * //     type: 'float32',
+ * //     data: Float32Array(512) [-0.09079201519489288, ...],
+ * //     size: 512
+ * //   }
+ * // }
+ * ```
+ */
+export class WavLMForXVector extends WavLMPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise<XVectorOutput>} An object containing the model's output logits and speaker embeddings.
+     */
+    async _call(model_inputs) {
+        return new XVectorOutput(await super._call(model_inputs));
+    }
+}
+
 //////////////////////////////////////////////////
 // SpeechT5 models
 /**
@@ -5483,6 +5526,10 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
 ]);
 
+const MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = new Map([
+    ['wavlm', ['WavLMForXVector', WavLMForXVector]],
+]);
+
 const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
     ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
 ]);
@@ -5523,6 +5570,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
     [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
 ];
 
 for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -5741,6 +5789,10 @@ export class AutoModelForAudioClassification extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES];
 }
 
+export class AutoModelForXVector extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES];
+}
+
 export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
 }
@@ -5793,6 +5845,22 @@ export class SequenceClassifierOutput extends ModelOutput {
     }
 }
 
+/**
+ * Base class for outputs of XVector models.
+ */
+export class XVectorOutput extends ModelOutput {
+    /**
+     * @param {Object} output The output of the model.
+     * @param {Tensor} output.logits Classification hidden states before AMSoftmax, of shape `(batch_size, config.xvector_output_dim)`.
+     * @param {Tensor} output.embeddings Utterance embeddings used for vector similarity-based retrieval, of shape `(batch_size, config.xvector_output_dim)`.
+     */
+    constructor({ logits, embeddings }) {
+        super();
+        this.logits = logits;
+        this.embeddings = embeddings;
+    }
+}
+
 /**
  * Base class for outputs of token classification models.
  */
diff --git a/src/processors.js b/src/processors.js
index 463f82551..a8b82e913 100644
--- a/src/processors.js
+++ b/src/processors.js
@@ -158,7 +158,7 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
 function validate_audio_inputs(audio, feature_extractor) {
     if (!(audio instanceof Float32Array || audio instanceof Float64Array)) {
         throw new Error(
-            `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead.` +
+            `${feature_extractor} expects input to be a Float32Array or a Float64Array, but got ${audio?.constructor?.name ?? typeof audio} instead. ` +
            `If using the feature extractor directly, remember to use \`read_audio(url, sampling_rate)\` to obtain the raw audio data of the file/url.`
         )
     }
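For context, the snippet below is a minimal sketch of how the `AutoModelForXVector` class and the `XVectorOutput.embeddings` tensor added in this patch could be used for speaker verification by comparing two utterance embeddings. It assumes the `cos_sim` helper exported by `@xenova/transformers`; the audio URLs and the decision threshold are illustrative placeholders, not values from this patch.

```javascript
import { AutoProcessor, AutoModelForXVector, read_audio, cos_sim } from '@xenova/transformers';

// Load the processor and XVector model (same checkpoint as in the docs example above).
const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sv');
const model = await AutoModelForXVector.from_pretrained('Xenova/wavlm-base-plus-sv');

// Compute a speaker embedding for a 16 kHz audio file or URL.
async function embed(url) {
    const audio = await read_audio(url, 16000);
    const inputs = await processor(audio);
    const { embeddings } = await model(inputs);
    return Array.from(embeddings.data); // length config.xvector_output_dim (512 here)
}

// Placeholder URLs: two utterances to compare.
const emb1 = await embed('https://example.com/utterance_1.wav');
const emb2 = await embed('https://example.com/utterance_2.wav');

// Cosine similarity of the two embeddings; the threshold is illustrative and should be tuned.
const similarity = cos_sim(emb1, emb2);
const THRESHOLD = 0.86;
console.log(similarity >= THRESHOLD ? 'Likely the same speaker' : 'Likely different speakers');
```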