From fcc5ba37f225269414dc6462eb7f599c17668d7e Mon Sep 17 00:00:00 2001 From: Joshua Lochner <26504141+xenova@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:04:57 -0500 Subject: [PATCH 1/6] Add support for FixedLength pre-tokenizer --- src/tokenizers.js | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/tokenizers.js b/src/tokenizers.js index 33cdb3e65..156df8d4f 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -1392,6 +1392,8 @@ class PreTokenizer extends Callable { return new DigitsPreTokenizer(config); case 'Replace': return new ReplacePreTokenizer(config); + case 'FixedLength': + return new FixedLengthPreTokenizer(config); default: throw new Error(`Unknown PreTokenizer type: ${config.type}`); } @@ -2510,6 +2512,31 @@ class ReplacePreTokenizer extends PreTokenizer { } } +class FixedLengthPreTokenizer extends PreTokenizer { + /** + * @param {Object} config The configuration options for the pre-tokenizer. + * @param {number} config.length The fixed length to split the text into. + */ + constructor(config) { + super(); + this._length = config.length; + } + + /** + * Pre-tokenizes the input text by splitting it into fixed-length tokens. + * @param {string} text The text to be pre-tokenized. + * @param {Object} [options] Additional options for the pre-tokenization logic. + * @returns {string[]} An array of tokens produced by splitting the input text into fixed-length tokens. + */ + pre_tokenize_text(text, options) { + const tokens = []; + for (let i = 0; i < text.length; i += this._length) { + tokens.push(text.slice(i, i + this._length)); + } + return tokens; + } +} + const SPECIAL_TOKEN_ATTRIBUTES = [ 'bos_token', 'eos_token', From 7e5422c6aa0d70af760e31f48bef118e8ad0c037 Mon Sep 17 00:00:00 2001 From: Joshua Lochner <26504141+xenova@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:05:04 -0500 Subject: [PATCH 2/6] Create randn tensor function --- src/utils/tensor.js | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index ea822b6c6..0022a4f72 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -846,7 +846,11 @@ export class Tensor { map_fn = Number; } else if (!is_source_bigint && is_dest_bigint) { // TypeError: Cannot convert [x] to a BigInt - map_fn = BigInt; + if (['float16', 'float32', 'float64'].includes(this.type)) { + map_fn = (x) => BigInt(Math.floor(x)); + } else { + map_fn = BigInt; + } } // @ts-ignore @@ -1525,6 +1529,29 @@ export function rand(size) { ) } +/** + * Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 (also called the standard normal distribution). + * @param {number[]} size A sequence of integers defining the shape of the output tensor. + * @returns {Tensor} The random tensor. + */ +export function randn(size) { + const length = size.reduce((a, b) => a * b, 1); + + // Box-Muller transform + function boxMullerRandom() { + // NOTE: 1 - Math.random() is used to avoid log(0) + const u = 1 - Math.random(); + const v = 1 - Math.random(); + return Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); + } + + return new Tensor( + "float32", + Float32Array.from({ length }, () => boxMullerRandom()), + size, + ) +} + /** * Quantizes the embeddings tensor to binary or unsigned binary precision. * @param {Tensor} tensor The tensor to quantize. 
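A minimal sketch of the behaviour the two patches above introduce. The `FixedLength` pre-tokenizer is normally driven by a tokenizer's `tokenizer.json` entry (`{ "type": "FixedLength", "length": n }`), so the splitting rule is shown inline here rather than through a specific checkpoint; the import assumes `randn` is re-exported from the package root like the other tensor helpers, which may differ in practice.

```javascript
import { randn } from '@huggingface/transformers';

// FixedLength pre-tokenization: split text into chunks of `length` characters.
const length = 4;
const text = 'ACGTACGTAC';
const chunks = [];
for (let i = 0; i < text.length; i += length) {
    chunks.push(text.slice(i, i + length));
}
console.log(chunks); // ['ACGT', 'ACGT', 'AC']

// Standard-normal tensor of shape [2, 3], filled via the Box-Muller transform.
const noise = randn([2, 3]);
console.log(noise.dims); // [2, 3]
```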
From 34c07a603731337fcede1987c2695e13d8c00fef Mon Sep 17 00:00:00 2001 From: Joshua Lochner <26504141+xenova@users.noreply.github.com> Date: Wed, 19 Nov 2025 00:05:28 -0500 Subject: [PATCH 3/6] Add support for Supertonic TTS models --- src/models.js | 65 +++++++++++++++++++++++++++++++++++++++ src/pipelines.js | 80 ++++++++++++++++++++++++++++++++++++------------ 2 files changed, 125 insertions(+), 20 deletions(-) diff --git a/src/models.js b/src/models.js index 7d31e2ed3..69d81bcf2 100644 --- a/src/models.js +++ b/src/models.js @@ -109,6 +109,7 @@ import { std_mean, Tensor, DataTypeMap, + randn, } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; @@ -136,6 +137,7 @@ const MODEL_TYPES = { AudioTextToText: 10, AutoEncoder: 11, ImageAudioTextToText: 12, + Supertonic: 13, } ////////////////////////////////////////////////// @@ -1262,6 +1264,14 @@ export class PreTrainedModel extends Callable { decoder_model: 'decoder_model', }, options), ]); + } else if (modelType === MODEL_TYPES.Supertonic) { + info = await Promise.all([ + constructSessions(pretrained_model_name_or_path, { + text_encoder: 'text_encoder', + latent_denoiser: 'latent_denoiser', + voice_decoder: 'voice_decoder', + }, options), + ]); } else { // should be MODEL_TYPES.EncoderOnly if (modelType !== MODEL_TYPES.EncoderOnly) { const type = modelName ?? config?.model_type; @@ -6861,6 +6871,59 @@ export class SpeechT5HifiGan extends PreTrainedModel { } ////////////////////////////////////////////////// +export class SupertonicPreTrainedModel extends PreTrainedModel { } +export class SupertonicForConditionalGeneration extends SupertonicPreTrainedModel { + + async generate_speech({ + // Required inputs + input_ids, + attention_mask, + style, + + // Optional inputs + num_inference_steps = 5, + }) { + // @ts-expect-error TS2339 + const { sampling_rate, chunk_compress_factor, base_chunk_size, latent_dim } = this.config; + + // 1. Text Encoder + const { last_hidden_state, durations } = await sessionRun(this.sessions['text_encoder'], { + input_ids, attention_mask, style, + }); + + // 2. Latent Denoiser + const wav_len_max = durations.max().item() * sampling_rate; + const chunk_size = base_chunk_size * chunk_compress_factor; + const latent_len = Math.floor((wav_len_max + chunk_size - 1) / chunk_size); + const batch_size = input_ids.dims[0]; + const latent_mask = ones([batch_size, latent_len]); + const num_steps = full([batch_size], num_inference_steps); + + let noisy_latents = randn([batch_size, latent_dim * chunk_compress_factor, latent_len]); + for (let step = 0; step < num_inference_steps; ++step) { + const timestep = full([batch_size], step); + ({ denoised_latents: noisy_latents } = await sessionRun(this.sessions['latent_denoiser'], { + style, + noisy_latents, + latent_mask, + encoder_outputs: last_hidden_state, + attention_mask, + timestep, + num_inference_steps: num_steps, + })); + } + + // 3. 
Voice Decoder + const { waveform } = await sessionRun(this.sessions['voice_decoder'], { + latents: noisy_latents, + }); + return { + waveform, + durations, + } + } +} + ////////////////////////////////////////////////// // TrOCR models @@ -8000,6 +8063,7 @@ const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([ const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([ ['vits', ['VitsModel', VitsModel]], ['musicgen', ['MusicgenForConditionalGeneration', MusicgenForConditionalGeneration]], + ['supertonic', ['SupertonicForConditionalGeneration', SupertonicForConditionalGeneration]], ]); const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ @@ -8386,6 +8450,7 @@ const CUSTOM_MAPPING = [ ['SnacDecoderModel', SnacDecoderModel, MODEL_TYPES.EncoderOnly], ['Gemma3nForConditionalGeneration', Gemma3nForConditionalGeneration, MODEL_TYPES.ImageAudioTextToText], + ['SupertonicForConditionalGeneration', SupertonicForConditionalGeneration, MODEL_TYPES.Supertonic], ] for (const [name, model, type] of CUSTOM_MAPPING) { MODEL_TYPE_MAPPING.set(name, type); diff --git a/src/pipelines.js b/src/pipelines.js index 6c84403e2..0c42aa7ff 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -2792,6 +2792,8 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: * * @typedef {Object} TextToAudioPipelineOptions Parameters specific to text-to-audio pipelines. * @property {Tensor|Float32Array|string|URL} [speaker_embeddings=null] The speaker embeddings (if the model requires it). + * @property {number} [num_inference_steps=5] The number of denoising steps (if the model supports it). + * More denoising steps usually lead to higher quality audio but slower inference. * * @callback TextToAudioPipelineCallback Generates speech/audio from the inputs. * @param {string|string[]} texts The text(s) to generate. @@ -2850,20 +2852,74 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi this.vocoder = options.vocoder ?? 
null; } + async _prepare_speaker_embeddings(speaker_embeddings) { + // Load speaker embeddings as Float32Array from path/URL + if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) { + // Load from URL with fetch + speaker_embeddings = new Float32Array( + await (await fetch(speaker_embeddings)).arrayBuffer() + ); + } + + if (speaker_embeddings instanceof Float32Array) { + speaker_embeddings = new Tensor( + 'float32', + speaker_embeddings, + [speaker_embeddings.length] + ) + } else if (!(speaker_embeddings instanceof Tensor)) { + throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.") + } + + return speaker_embeddings; + } /** @type {TextToAudioPipelineCallback} */ async _call(text_inputs, { speaker_embeddings = null, + num_inference_steps = 5, } = {}) { // If this.processor is not set, we are using a `AutoModelForTextToWaveform` model if (this.processor) { return this._call_text_to_spectrogram(text_inputs, { speaker_embeddings }); + } else if ( + this.model.config.model_type === "supertonic" + ) { + return this._call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps }); } else { return this._call_text_to_waveform(text_inputs); } } + async _call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps }) { + if (!speaker_embeddings) { + throw new Error("Speaker embeddings must be provided for Supertonic models."); + } + speaker_embeddings = await this._prepare_speaker_embeddings(speaker_embeddings); + + // @ts-expect-error TS2339 + const { sampling_rate, style_dim } = this.model.config; + + speaker_embeddings = (/** @type {Tensor} */ (speaker_embeddings)).view(1, -1, style_dim); + const inputs = this.tokenizer(text_inputs, { + padding: true, + truncation: true, + }); + + // @ts-expect-error TS2339 + const { waveform } = await this.model.generate_speech({ + ...inputs, + style: speaker_embeddings, + num_inference_steps, + }); + + return new RawAudio( + waveform.data, + sampling_rate, + ) + } + async _call_text_to_waveform(text_inputs) { // Run tokenization @@ -2891,32 +2947,16 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi this.vocoder = await AutoModel.from_pretrained(this.DEFAULT_VOCODER_ID, { dtype: 'fp32' }); } - // Load speaker embeddings as Float32Array from path/URL - if (typeof speaker_embeddings === 'string' || speaker_embeddings instanceof URL) { - // Load from URL with fetch - speaker_embeddings = new Float32Array( - await (await fetch(speaker_embeddings)).arrayBuffer() - ); - } - - if (speaker_embeddings instanceof Float32Array) { - speaker_embeddings = new Tensor( - 'float32', - speaker_embeddings, - [1, speaker_embeddings.length] - ) - } else if (!(speaker_embeddings instanceof Tensor)) { - throw new Error("Speaker embeddings must be a `Tensor`, `Float32Array`, `string`, or `URL`.") - } - // Run tokenization const { input_ids } = this.tokenizer(text_inputs, { padding: true, truncation: true, }); - // NOTE: At this point, we are guaranteed that `speaker_embeddings` is a `Tensor` - // @ts-ignore + speaker_embeddings = await this._prepare_speaker_embeddings(speaker_embeddings); + speaker_embeddings = speaker_embeddings.view(1, -1); + + // @ts-expect-error TS2339 const { waveform } = await this.model.generate_speech(input_ids, speaker_embeddings, { vocoder: this.vocoder }); const sampling_rate = this.processor.feature_extractor.config.sampling_rate; From 6b48f334bc1b4e35b3cbeaf3af56c2f251e195d5 Mon Sep 17 00:00:00 2001 From: Joshua Lochner 
<26504141+xenova@users.noreply.github.com> Date: Wed, 19 Nov 2025 09:48:42 -0500 Subject: [PATCH 4/6] Update TTS JSDoc --- src/pipelines.js | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/src/pipelines.js b/src/pipelines.js index 0c42aa7ff..5b38b7246 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -2807,31 +2807,24 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: * Text-to-audio generation pipeline using any `AutoModelForTextToWaveform` or `AutoModelForTextToSpectrogram`. * This pipeline generates an audio file from an input text and optional other conditional inputs. * - * **Example:** Generate audio from text with `Xenova/speecht5_tts`. + * **Example:** Generate audio from text with `onnx-community/Supertonic-TTS-ONNX`. * ```javascript - * const synthesizer = await pipeline('text-to-speech', 'Xenova/speecht5_tts', { quantized: false }); - * const speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin'; - * const out = await synthesizer('Hello, my dog is cute', { speaker_embeddings }); + * const synthesizer = await pipeline('text-to-speech', 'onnx-community/Supertonic-TTS-ONNX'); + * const speaker_embeddings = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin'; + * const output = await synthesizer('Hello there, how are you doing?', { speaker_embeddings }); * // RawAudio { - * // audio: Float32Array(26112) [-0.00005657337896991521, 0.00020583874720614403, ...], - * // sampling_rate: 16000 + * // audio: Float32Array(101376) [-0.00006606941315112635, -0.00006774164648959413, ...], + * // sampling_rate: 44100 * // } - * ``` - * - * You can then save the audio to a .wav file with the `wavefile` package: - * ```javascript - * import wavefile from 'wavefile'; - * import fs from 'fs'; - * - * const wav = new wavefile.WaveFile(); - * wav.fromScratch(1, out.sampling_rate, '32f', out.audio); - * fs.writeFileSync('out.wav', wav.toBuffer()); + * + * // Optional: Save the audio to a .wav file or Blob + * await output.save('output.wav'); // You can also use `output.toBlob()` to access the audio as a Blob * ``` * * **Example:** Multilingual speech generation with `Xenova/mms-tts-fra`. See [here](https://huggingface.co/models?pipeline_tag=text-to-speech&other=vits&sort=trending) for the full list of available languages (1107). 
* ```javascript * const synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra'); - * const out = await synthesizer('Bonjour'); + * const output = await synthesizer('Bonjour'); * // RawAudio { * // audio: Float32Array(23808) [-0.00037693005288019776, 0.0003325853613205254, ...], * // sampling_rate: 16000 From 8349ff99dbd621fd4033dca603d37b932622e153 Mon Sep 17 00:00:00 2001 From: Joshua Lochner <26504141+xenova@users.noreply.github.com> Date: Wed, 19 Nov 2025 09:55:21 -0500 Subject: [PATCH 5/6] Add supertonic speed parameter --- src/models.js | 2 ++ src/pipelines.js | 13 ++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/models.js b/src/models.js index 69d81bcf2..e6f251c92 100644 --- a/src/models.js +++ b/src/models.js @@ -6882,6 +6882,7 @@ export class SupertonicForConditionalGeneration extends SupertonicPreTrainedMode // Optional inputs num_inference_steps = 5, + speed = 1.05, }) { // @ts-expect-error TS2339 const { sampling_rate, chunk_compress_factor, base_chunk_size, latent_dim } = this.config; @@ -6890,6 +6891,7 @@ export class SupertonicForConditionalGeneration extends SupertonicPreTrainedMode const { last_hidden_state, durations } = await sessionRun(this.sessions['text_encoder'], { input_ids, attention_mask, style, }); + durations.div_(speed); // Apply speed factor to duration // 2. Latent Denoiser const wav_len_max = durations.max().item() * sampling_rate; diff --git a/src/pipelines.js b/src/pipelines.js index 5b38b7246..1032a7dea 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -2792,8 +2792,9 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: * * @typedef {Object} TextToAudioPipelineOptions Parameters specific to text-to-audio pipelines. * @property {Tensor|Float32Array|string|URL} [speaker_embeddings=null] The speaker embeddings (if the model requires it). - * @property {number} [num_inference_steps=5] The number of denoising steps (if the model supports it). + * @property {number} [num_inference_steps] The number of denoising steps (if the model supports it). * More denoising steps usually lead to higher quality audio but slower inference. + * @property {number} [speed] The speed of the generated audio (if the model supports it). * * @callback TextToAudioPipelineCallback Generates speech/audio from the inputs. * @param {string|string[]} texts The text(s) to generate. 
@@ -2813,7 +2814,7 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options: * const speaker_embeddings = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin'; * const output = await synthesizer('Hello there, how are you doing?', { speaker_embeddings }); * // RawAudio { - * // audio: Float32Array(101376) [-0.00006606941315112635, -0.00006774164648959413, ...], + * // audio: Float32Array(95232) [-0.000482565927086398, -0.0004853440332226455, ...], * // sampling_rate: 44100 * // } * @@ -2870,7 +2871,8 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi /** @type {TextToAudioPipelineCallback} */ async _call(text_inputs, { speaker_embeddings = null, - num_inference_steps = 5, + num_inference_steps, + speed, } = {}) { // If this.processor is not set, we are using a `AutoModelForTextToWaveform` model @@ -2879,13 +2881,13 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi } else if ( this.model.config.model_type === "supertonic" ) { - return this._call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps }); + return this._call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed }); } else { return this._call_text_to_waveform(text_inputs); } } - async _call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps }) { + async _call_supertonic(text_inputs, { speaker_embeddings, num_inference_steps, speed }) { if (!speaker_embeddings) { throw new Error("Speaker embeddings must be provided for Supertonic models."); } @@ -2905,6 +2907,7 @@ export class TextToAudioPipeline extends (/** @type {new (options: TextToAudioPi ...inputs, style: speaker_embeddings, num_inference_steps, + speed, }); return new RawAudio( From 323018f745b13b9fa55597abfd49d438af3c852a Mon Sep 17 00:00:00 2001 From: Joshua Lochner <26504141+xenova@users.noreply.github.com> Date: Wed, 19 Nov 2025 10:07:19 -0500 Subject: [PATCH 6/6] Update list of supported models --- README.md | 3 ++- docs/snippets/6_supported-models.snippet | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cd84d99fd..21ac7da27 100644 --- a/README.md +++ b/README.md @@ -432,7 +432,8 @@ You can refine your search by selecting the task you're interested in (e.g., [te 1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://huggingface.co/papers/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm)** (from Stability AI) released with the paper [StableLM 3B 4E1T (Technical Report)](https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo) by Jonathan Tow, Marco Bellagente, Dakota Mahan, Carlos Riquelme Ruiz, Duy Phung, Maksym Zhuravinskyi, Nathan Cooper, Nikhil Pinnaparaju, Reshinth Adithyan, and James Baicoianu. 1. 
**[Starcoder2](https://huggingface.co/docs/transformers/main/model_doc/starcoder2)** (from BigCode team) released with the paper [StarCoder 2 and The Stack v2: The Next Generation](https://huggingface.co/papers/2402.19173) by Anton Lozhkov, Raymond Li, Loubna Ben Allal, Federico Cassano, Joel Lamy-Poirier, Nouamane Tazi, Ao Tang, Dmytro Pykhtar, Jiawei Liu, Yuxiang Wei, Tianyang Liu, Max Tian, Denis Kocetkov, Arthur Zucker, Younes Belkada, Zijian Wang, Qian Liu, Dmitry Abulkhanov, Indraneil Paul, Zhuang Li, Wen-Ding Li, Megan Risdal, Jia Li, Jian Zhu, Terry Yue Zhuo, Evgenii Zheltonozhskii, Nii Osae Osae Dade, Wenhao Yu, Lucas Krauß, Naman Jain, Yixuan Su, Xuanli He, Manan Dey, Edoardo Abati, Yekun Chai, Niklas Muennighoff, Xiangru Tang, Muhtasham Oblokulov, Christopher Akiki, Marc Marone, Chenghao Mou, Mayank Mishra, Alex Gu, Binyuan Hui, Tri Dao, Armel Zebaze, Olivier Dehaene, Nicolas Patry, Canwen Xu, Julian McAuley, Han Hu, Torsten Scholak, Sebastien Paquet, Jennifer Robinson, Carolyn Jane Anderson, Nicolas Chapados, Mostofa Patwary, Nima Tajbakhsh, Yacine Jernite, Carlos Muñoz Ferrandis, Lingming Zhang, Sean Hughes, Thomas Wolf, Arjun Guha, Leandro von Werra, and Harm de Vries. -1. StyleTTS 2 (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani. +1. **StyleTTS 2** (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani. +1. **Supertonic** (from Supertone) released with the paper [Training Flow Matching Models with Reliable Labels via Self-Purification](https://huggingface.co/papers/2509.19091) by Hyeongju Kim, Yechan Yu, June Young Yi, Juheon Lee. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://huggingface.co/papers/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin2SR](https://huggingface.co/docs/transformers/model_doc/swin2sr)** (from University of Würzburg) released with the paper [Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://huggingface.co/papers/2209.11345) by Marcos V. Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet index 728f39803..084212549 100644 --- a/docs/snippets/6_supported-models.snippet +++ b/docs/snippets/6_supported-models.snippet @@ -146,7 +146,8 @@ 1. 
**[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://huggingface.co/papers/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. 1. **[StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm)** (from Stability AI) released with the paper [StableLM 3B 4E1T (Technical Report)](https://stability.wandb.io/stability-llm/stable-lm/reports/StableLM-3B-4E1T--VmlldzoyMjU4?accessToken=u3zujipenkx5g7rtcj9qojjgxpconyjktjkli2po09nffrffdhhchq045vp0wyfo) by Jonathan Tow, Marco Bellagente, Dakota Mahan, Carlos Riquelme Ruiz, Duy Phung, Maksym Zhuravinskyi, Nathan Cooper, Nikhil Pinnaparaju, Reshinth Adithyan, and James Baicoianu. 1. **[Starcoder2](https://huggingface.co/docs/transformers/main/model_doc/starcoder2)** (from BigCode team) released with the paper [StarCoder 2 and The Stack v2: The Next Generation](https://huggingface.co/papers/2402.19173) by Anton Lozhkov, Raymond Li, Loubna Ben Allal, Federico Cassano, Joel Lamy-Poirier, Nouamane Tazi, Ao Tang, Dmytro Pykhtar, Jiawei Liu, Yuxiang Wei, Tianyang Liu, Max Tian, Denis Kocetkov, Arthur Zucker, Younes Belkada, Zijian Wang, Qian Liu, Dmitry Abulkhanov, Indraneil Paul, Zhuang Li, Wen-Ding Li, Megan Risdal, Jia Li, Jian Zhu, Terry Yue Zhuo, Evgenii Zheltonozhskii, Nii Osae Osae Dade, Wenhao Yu, Lucas Krauß, Naman Jain, Yixuan Su, Xuanli He, Manan Dey, Edoardo Abati, Yekun Chai, Niklas Muennighoff, Xiangru Tang, Muhtasham Oblokulov, Christopher Akiki, Marc Marone, Chenghao Mou, Mayank Mishra, Alex Gu, Binyuan Hui, Tri Dao, Armel Zebaze, Olivier Dehaene, Nicolas Patry, Canwen Xu, Julian McAuley, Han Hu, Torsten Scholak, Sebastien Paquet, Jennifer Robinson, Carolyn Jane Anderson, Nicolas Chapados, Mostofa Patwary, Nima Tajbakhsh, Yacine Jernite, Carlos Muñoz Ferrandis, Lingming Zhang, Sean Hughes, Thomas Wolf, Arjun Guha, Leandro von Werra, and Harm de Vries. -1. StyleTTS 2 (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani. +1. **StyleTTS 2** (from Columbia University) released with the paper [StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models](https://huggingface.co/papers/2306.07691) by Yinghao Aaron Li, Cong Han, Vinay S. Raghavan, Gavin Mischler, Nima Mesgarani. +1. **Supertonic** (from Supertone) released with the paper [Training Flow Matching Models with Reliable Labels via Self-Purification](https://huggingface.co/papers/2509.19091) by Hyeongju Kim, Yechan Yu, June Young Yi, Juheon Lee. 1. **[Swin Transformer](https://huggingface.co/docs/transformers/model_doc/swin)** (from Microsoft) released with the paper [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://huggingface.co/papers/2103.14030) by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo. 1. **[Swin2SR](https://huggingface.co/docs/transformers/model_doc/swin2sr)** (from University of Würzburg) released with the paper [Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration](https://huggingface.co/papers/2209.11345) by Marcos V. 
Conde, Ui-Jin Choi, Maxime Burchi, Radu Timofte. 1. **[T5](https://huggingface.co/docs/transformers/model_doc/t5)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://huggingface.co/papers/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
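Taken together, these patches enable an end-to-end text-to-speech flow along the lines of the sketch below. It follows the JSDoc example from PATCH 4/6 and the options added in PATCH 5/6; the explicit option values are for illustration only (the model-side defaults are `num_inference_steps = 5` and `speed = 1.05`).

```javascript
import { pipeline } from '@huggingface/transformers';

// Speaker embeddings are required for Supertonic models; here they are loaded
// from a voice file hosted alongside the ONNX weights.
const synthesizer = await pipeline('text-to-speech', 'onnx-community/Supertonic-TTS-ONNX');
const speaker_embeddings = 'https://huggingface.co/onnx-community/Supertonic-TTS-ONNX/resolve/main/voices/F1.bin';

const output = await synthesizer('Hello there, how are you doing?', {
    speaker_embeddings,
    num_inference_steps: 5, // more denoising steps: higher quality, slower inference
    speed: 1.05,            // values > 1 speed the speech up, values < 1 slow it down
});
// RawAudio { audio: Float32Array(...), sampling_rate: 44100 }

await output.save('output.wav');
```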