From 3d35f66e8b724cdee5d8dc5ade9627e32345379c Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Fri, 23 Feb 2024 01:35:59 +0100
Subject: [PATCH 1/9] Add WavLMForXVector support

---
 src/models.js | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/src/models.js b/src/models.js
index 46c112a0b..19e183588 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4735,6 +4735,47 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
     }
 }
 
+/**
+ * WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
+ *
+ * **Example:** Extract speaker embeddings with `WavLMForXVector`.
+ * ```javascript
+ * import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';
+ *
+ * const processor = await AutoProcessor.from_pretrained('D4ve-R/wavlm-base-plus-sv');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
+ * const audio = await read_audio(url, 16000);
+ * const inputs = await processor(audio);
+ *
+ * const model = await AutoModel.from_pretrained('D4ve-R/wavlm-base-plus-sv', {quantized: false});
+ * const outputs = await model(inputs);
+ * // {
+ * //   embeddings: Tensor {
+ * //     dims: [ 1, 512 ],
+ * //     type: 'float32',
+ * //     data: Float32Array(512) [-0.349443256855011, ...],
+ * //     size: 512
+ * //   },
+ * //   logits: Tensor {
+ * //     dims: [ 1, 512 ],
+ * //     type: 'float32',
+ * //     data: Float32Array(512) [0.022836603224277496, ...],
+ * //     size: 512
+ * //   }
+ * // }
+ * ```
+ */
+export class WavLMForXVector extends WavLMPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise} An object containing the model's output logits and speaker embeddings.
+     */
+    async _call(model_inputs) {
+        return new XVectorOutput(await super._call(model_inputs));
+    }
+}
+
 //////////////////////////////////////////////////
 // SpeechT5 models
 /**
@@ -5483,6 +5524,10 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
 ]);
 
+const MODEL_FOR_SPEAKER_VERIFICATION_MAPPING_NAMES = new Map([
+    ['wavlm', ['WavLMForXVector', WavLMForXVector]],
+]);
+
 const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
     ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
 ]);
@@ -5523,6 +5568,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
     [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_SPEAKER_VERIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
 ];
 
 for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -5741,6 +5787,10 @@ export class AutoModelForAudioClassification extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES];
 }
 
+export class AutoModelForSpeakerVerification extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEAKER_VERIFICATION_MAPPING_NAMES];
+}
+
 export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
 }
@@ -5793,6 +5843,22 @@ export class SequenceClassifierOutput extends ModelOutput {
     }
 }
 
+/**
+ * Base class for outputs of x-vector models.
+ */
+export class XVectorOutput extends ModelOutput {
+    /**
+     * @param {Object} output The output of the model.
+     * @param {Tensor} output.logits Classification (or regression if config.num_labels==1) scores (before SoftMax).
+     * @param {Tensor} output.embeddings The embeddings of the input sequence.
+     */
+    constructor({ logits, embeddings }) {
+        super();
+        this.logits = logits;
+        this.embeddings = embeddings;
+    }
+}
+
 /**
  * Base class for outputs of token classification models.
  */
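The new `XVectorOutput` above exposes both raw logits and the pooled speaker embedding, but the patch itself stops short of the actual verification step. Below is a minimal sketch of that step, assuming the same `D4ve-R/wavlm-base-plus-sv` checkpoint, two hypothetical utterance URLs (`urlA`, `urlB`), and a 0.86 cosine-similarity acceptance threshold (the value suggested in the upstream WavLM speaker-verification model cards; tune it per application):

```javascript
import { AutoProcessor, AutoModel, read_audio } from '@xenova/transformers';

// Hypothetical pair of 16 kHz utterances to compare.
const urlA = 'https://example.com/speaker_a.wav';
const urlB = 'https://example.com/speaker_b.wav';

const processor = await AutoProcessor.from_pretrained('D4ve-R/wavlm-base-plus-sv');
const model = await AutoModel.from_pretrained('D4ve-R/wavlm-base-plus-sv');

// Embed one utterance as a 512-dim x-vector (the flat Float32Array of the output Tensor).
async function embed(url) {
    const audio = await read_audio(url, 16000);
    const { embeddings } = await model(await processor(audio));
    return embeddings.data;
}

// Cosine similarity between two equal-length vectors.
function cosineSimilarity(a, b) {
    let dot = 0, normA = 0, normB = 0;
    for (let i = 0; i < a.length; ++i) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

const similarity = cosineSimilarity(await embed(urlA), await embed(urlB));
console.log(similarity >= 0.86 ? 'likely the same speaker' : 'likely different speakers');
```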
From 6d09c47e4c22171f6c6fc37c4a2a795586f6ddb7 Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Fri, 23 Feb 2024 10:52:29 +0100
Subject: [PATCH 2/9] fix model docs

---
 src/models.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models.js b/src/models.js
index 19e183588..6c823bfe5 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4747,7 +4747,7 @@ export class WavLMForSequenceClassification extends WavLMPreTrainedModel {
  * const audio = await read_audio(url, 16000);
  * const inputs = await processor(audio);
  *
- * const model = await AutoModel.from_pretrained('D4ve-R/wavlm-base-plus-sv', {quantized: false});
+ * const model = await AutoModel.from_pretrained('D4ve-R/wavlm-base-plus-sv');
  * const outputs = await model(inputs);
  * // {
  * //   embeddings: Tensor {

From 3079d2c1a283959821723033251b9beba9a542e2 Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Fri, 23 Feb 2024 11:17:03 +0100
Subject: [PATCH 3/9] Add WavLMForAudioFrameClassification

---
 src/models.js | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/src/models.js b/src/models.js
index 6c823bfe5..d6d5a2dff 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4776,6 +4776,17 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
     }
 }
 
+export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise} An object containing the model's output logits for audio frame classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
+
 //////////////////////////////////////////////////
 // SpeechT5 models
 /**
@@ -5524,6 +5535,10 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
 ]);
 
+const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([
+    ['wavlm', ['WavLMForFrameClassification', WavLMForAudioFrameClassification]],
+]);
+
 const MODEL_FOR_SPEAKER_VERIFICATION_MAPPING_NAMES = new Map([
     ['wavlm', ['WavLMForXVector', WavLMForXVector]],
 ]);
@@ -5569,6 +5584,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
     [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_SPEAKER_VERIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
 ];
 
 for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -5787,6 +5803,10 @@ export class AutoModelForAudioClassification extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES];
 }
 
+export class AutoModelForAudioFrameClassification extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES];
+}
+
 export class AutoModelForSpeakerVerification extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SPEAKER_VERIFICATION_MAPPING_NAMES];
 }
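The hunks above touch three registries at once: the name map keyed by `config.model_type`, the `MODEL_CLASS_TYPE_MAPPING` entry that tells the loader to build an encoder-only session, and the `AutoModelForAudioFrameClassification` facade. A simplified, hypothetical sketch of how such a facade resolves a class from its maps (the real logic lives in `PretrainedMixin.from_pretrained`; the dummy class and helper names here are illustrative only):

```javascript
// Simplified, hypothetical sketch of the Map-based dispatch these registries feed.
// This only illustrates how config.model_type selects the registered constructor.
class DummyWavLMForAudioFrameClassification { }

const FRAME_CLASSIFICATION_NAMES = new Map([
    ['wavlm', ['WavLMForAudioFrameClassification', DummyWavLMForAudioFrameClassification]],
]);

function resolveModelClass(config, mappings) {
    for (const mapping of mappings) {
        const entry = mapping.get(config.model_type);
        if (entry) return entry[1]; // [name, class] pair -> constructor
    }
    throw new Error(`Unsupported model type: ${config.model_type}`);
}

const cls = resolveModelClass({ model_type: 'wavlm' }, [FRAME_CLASSIFICATION_NAMES]);
console.log(cls.name); // 'DummyWavLMForAudioFrameClassification'
```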
From 750092ed50049aa896884929a7c2e959987b263b Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Fri, 23 Feb 2024 11:20:13 +0100
Subject: [PATCH 4/9] Add missing Wav2Vec2ForAudioFrameCl.

---
 src/models.js | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/models.js b/src/models.js
index d6d5a2dff..c803b7513 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4571,6 +4571,17 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
         return new SequenceClassifierOutput(await super._call(model_inputs));
     }
 }
+
+export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel {
+    /**
+     * Calls the model on new inputs.
+     * @param {Object} model_inputs The inputs to the model.
+     * @returns {Promise} An object containing the model's output logits for audio frame classification.
+     */
+    async _call(model_inputs) {
+        return new TokenClassifierOutput(await super._call(model_inputs));
+    }
+}
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -5536,7 +5547,8 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
 ]);
 
 const MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = new Map([
-    ['wavlm', ['WavLMForFrameClassification', WavLMForAudioFrameClassification]],
+    ['wavlm', ['WavLMForAudioFrameClassification', WavLMForAudioFrameClassification]],
+    ['wav2vec2', ['Wav2Vec2ForAudioFrameClassification', Wav2Vec2ForAudioFrameClassification]],
 ]);
 
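Since both the WavLM and Wav2Vec2 frame-classification heads wrap their result in `TokenClassifierOutput`, downstream code can stay backbone-agnostic. A sketch of per-frame argmax decoding that relies only on the output Tensor's `dims` and flat `data` buffer; the helper name is illustrative (not library API), and diarization checkpoints are typically decoded with per-speaker sigmoids instead, as later patches in this series show:

```javascript
// Backbone-agnostic per-frame argmax over TokenClassifierOutput logits.
// Works for any logits Tensor of shape [batch_size, num_frames, num_labels].
function framewiseArgmax(logits) {
    const [batchSize, numFrames, numLabels] = logits.dims;
    const data = logits.data; // Float32Array of length batchSize * numFrames * numLabels
    const result = [];
    for (let b = 0; b < batchSize; ++b) {
        const frames = [];
        for (let f = 0; f < numFrames; ++f) {
            const offset = (b * numFrames + f) * numLabels;
            let best = 0;
            for (let l = 1; l < numLabels; ++l) {
                if (data[offset + l] > data[offset + best]) best = l;
            }
            frames.push(best);
        }
        result.push(frames);
    }
    return result; // [batch_size][num_frames] label indices
}

// Usage: const frameLabels = framewiseArgmax((await model(inputs)).logits);
```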
From 627b48ce2433b022a496692d5c84e3dec2292afc Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Wed, 28 Feb 2024 22:40:17 +0100
Subject: [PATCH 5/9] Add doc comment

---
 src/models.js | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/src/models.js b/src/models.js
index 362f1692d..3a395203c 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4789,6 +4789,32 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
     }
 }
 
+/**
+ * WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
+ *
+ * **Example:** Perform speaker diarization with `WavLMForAudioFrameClassification`.
+ * ```javascript
+ * import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';
+ *
+ * // Read and preprocess audio
+ * const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
+ * const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
+ * const audio = await read_audio(url, 16000);
+ * const inputs = await processor(audio);
+ *
+ * // Run model with inputs
+ * const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
+ * const outputs = await model(inputs);
+ * // {
+ * //   logits: Tensor {
+ * //     dims: [ 1, 549, 2 ],
+ * //     type: 'float32',
+ * //     data: Float32Array(1098) [0.5847219228744507, ...],
+ * //     size: 1098
+ * //   }
+ * // }
+ * ```
+ */
 export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
     /**
      * Calls the model on new inputs.

From fc34610d9a72fb3b9c9644c6d15c8720d9a3c826 Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Wed, 28 Feb 2024 22:47:58 +0100
Subject: [PATCH 6/9] Add doc string wav2vec2

---
 src/models.js | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/models.js b/src/models.js
index 3a395203c..6549467a7 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4572,6 +4572,9 @@ export class Wav2Vec2ForSequenceClassification extends Wav2Vec2PreTrainedModel {
     }
 }
 
+/**
+ * Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
+ */
 export class Wav2Vec2ForAudioFrameClassification extends Wav2Vec2PreTrainedModel {
     /**
      * Calls the model on new inputs.
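The 549 frames in the example output cover roughly 11 seconds of audio, i.e. about 20 ms per frame, which matches the 320-sample stride of the convolutional feature encoder at 16 kHz. Assuming that stride (an assumption, since the exact frame rate depends on the checkpoint's feature-encoder config), a sketch that attaches timestamps to the per-frame probabilities from the doc example above:

```javascript
import { AutoProcessor, AutoModelForAudioFrameClassification, read_audio } from '@xenova/transformers';

// Same setup as the doc example above.
const processor = await AutoProcessor.from_pretrained('Xenova/wavlm-base-plus-sd');
const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
const audio = await read_audio('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav', 16000);
const { logits } = await model(await processor(audio));

// Assumed frame rate: the conv feature encoder hops 320 samples at 16 kHz (~20 ms/frame).
const SECONDS_PER_FRAME = 320 / 16000;

const probs = logits[0].sigmoid().tolist(); // [num_frames][num_speakers]
const timeline = probs.map((speakers, frame) => ({
    time: +(frame * SECONDS_PER_FRAME).toFixed(2), // seconds from clip start
    active: speakers.map(p => (p > 0.5 ? 1 : 0)),  // per-speaker activity flags
}));
console.log(timeline[250]); // e.g. { time: 5, active: [0, 1] }
```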
From 6221571abb81dbd23f2204b5d2f2552529389028 Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Mon, 4 Mar 2024 19:39:10 +0100
Subject: [PATCH 7/9] update comment

---
 src/models.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models.js b/src/models.js
index 6549467a7..44122b782 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4810,7 +4810,7 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
  * const outputs = await model(inputs);
  * // {
  * //   logits: Tensor {
- * //     dims: [ 1, 549, 2 ],
+ * //     dims: [ 1, 549, 2 ], // [num_batches, num_frames, num_speakers]
  * //     type: 'float32',
  * //     data: Float32Array(1098) [0.5847219228744507, ...],
  * //     size: 1098

From 3b0f1acf640d0d9686db78b3c830f73355b12fd2 Mon Sep 17 00:00:00 2001
From: Dave <69651599+D4ve-R@users.noreply.github.com>
Date: Tue, 5 Mar 2024 12:47:10 +0100
Subject: [PATCH 8/9] make example like python

---
 src/models.js | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/src/models.js b/src/models.js
index 44122b782..7a560e3cf 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4807,13 +4807,20 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
  *
  * // Run model with inputs
  * const model = await AutoModelForAudioFrameClassification.from_pretrained('Xenova/wavlm-base-plus-sd');
- * const outputs = await model(inputs);
+ * const { logits } = await model(inputs);
  * // {
  * //   logits: Tensor {
  * //     dims: [ 1, 549, 2 ], // [num_batches, num_frames, num_speakers]
  * //     type: 'float32',
  * //     data: Float32Array(1098) [0.5847219228744507, ...],
  * //     size: 1098
  * //   }
  * // }
+ * // labels is a one-hot array of shape (num_batches, num_frames, num_speakers)
+ * const labels = logits
+ *   .sigmoid()
+ *   .toList()
+ *   .map(frames => frames.map(
+ *     frame => frame.map(speaker => speaker > 0.5 ? 1 : 0);
+ *   ));
  * ```

From 008a079f2790fac91bb962ec139750c8bcb835c2 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Thu, 7 Mar 2024 14:52:16 +0200
Subject: [PATCH 9/9] Update src/models.js

---
 src/models.js | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/models.js b/src/models.js
index 7a560e3cf..342f3d74f 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4810,19 +4810,23 @@ export class WavLMForXVector extends WavLMPreTrainedModel {
  * const { logits } = await model(inputs);
  * // {
  * //   logits: Tensor {
- * //     dims: [ 1, 549, 2 ], // [num_batches, num_frames, num_speakers]
+ * //     dims: [ 1, 549, 2 ], // [batch_size, num_frames, num_speakers]
  * //     type: 'float32',
- * //     data: Float32Array(1098) [0.5847219228744507, ...],
+ * //     data: Float32Array(1098) [-3.5301010608673096, ...],
  * //     size: 1098
  * //   }
  * // }
- * // labels is a one-hot array of shape (num_batches, num_frames, num_speakers)
- * const labels = logits
- *   .sigmoid()
- *   .toList()
- *   .map(frames => frames.map(
- *     frame => frame.map(speaker => speaker > 0.5 ? 1 : 0);
- *   ));
+ *
+ * const labels = logits[0].sigmoid().tolist().map(
+ *     frames => frames.map(speaker => speaker > 0.5 ? 1 : 0)
+ * );
+ * console.log(labels); // labels is a one-hot array of shape (num_frames, num_speakers)
+ * // [
+ * //     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
+ * //     [0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0, 0],
+ * //     [0, 0], [0, 1], [0, 1], [0, 1], [0, 1], [0, 1],
+ * //     ...
+ * // ]
  * ```
  */
 export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel {
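The series ends with per-frame one-hot labels; a diarization consumer usually wants contiguous `{ speaker, start, end }` segments instead. A sketch of that run-length grouping over the final example's `labels` array, reusing the ~20 ms-per-frame assumption from the feature-encoder stride (the helper name and default are illustrative, not library API):

```javascript
// Collapse per-frame one-hot labels ([num_frames][num_speakers]) into
// contiguous { speaker, start, end } segments. Frame duration is assumed
// to be 320 / 16000 s (~20 ms), the feature-encoder stride at 16 kHz.
function labelsToSegments(labels, secondsPerFrame = 320 / 16000) {
    const numSpeakers = labels[0]?.length ?? 0;
    const segments = [];
    for (let s = 0; s < numSpeakers; ++s) {
        let start = null;
        // Iterate one frame past the end so an open segment is always closed.
        for (let f = 0; f <= labels.length; ++f) {
            const active = f < labels.length && labels[f][s] === 1;
            if (active && start === null) {
                start = f; // segment opens
            } else if (!active && start !== null) {
                segments.push({
                    speaker: s,
                    start: start * secondsPerFrame, // seconds
                    end: f * secondsPerFrame,
                });
                start = null; // segment closes
            }
        }
    }
    return segments.sort((a, b) => a.start - b.start);
}

// e.g. labelsToSegments(labels) -> [ { speaker: 1, start: 0.26, end: 1.04 }, ... ]
```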