Skip to content

Add xcodec2 model#44178

Open
ebezzam wants to merge 86 commits intohuggingface:mainfrom
ebezzam:add-xcodec2
Open

Add xcodec2 model#44178
ebezzam wants to merge 86 commits intohuggingface:mainfrom
ebezzam:add-xcodec2

Conversation

@ebezzam
Copy link
Copy Markdown
Contributor

@ebezzam ebezzam commented Feb 20, 2026

What does this PR do?

Re-opening #37868

TODO

  • recompute expected outputs
  • passthrough code given new conventions
  • check for unused code paths / configuration parameters

Original checkpoint: https://huggingface.co/HKUSTAudio/xcodec2
Original modeling code: https://huggingface.co/HKUSTAudio/xcodec2/blob/main/modeling_xcodec2.py

@HuggingFaceDocBuilderDev
Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ebezzam
Copy link
Copy Markdown
Contributor Author

ebezzam commented Mar 18, 2026

run-slow: xcodec2

@github-actions
Copy link
Copy Markdown
Contributor

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/xcodec2"]
quantizations: []

@github-actions
Copy link
Copy Markdown
Contributor

CI Results

Workflow Run ⚙️

Commit Info

Context Commit Description
RUN 6fd5f248 workflow commit (merge commit)
PR 1fbe78dc branch commit (from PR)
main 24a4dc22 base commit (on main)

✅ No failing test specific to this PR 🎉 👏 !

@github-actions
Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, dac, higgs_audio_v2_tokenizer, pe_audio, qwen2_5_omni, seamless_m4t, wav2vec2_bert, xcodec, xcodec2

Copy link
Copy Markdown
Contributor Author

@ebezzam ebezzam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eustlb a self-review review for X-Codec2!

Main things:

  • Unique feature extraction for DAC-like and SeamlessM4T-like input processing, as the model needs both padded audio and spectrogram inputs.
  • New type of components in modular: Xcodec2FiniteScalarQuantization and Xcodec2ISTFTHead (similar to what we saw in the Vocos PR)
  • Small tweaks/fixes for models that Xcodec2 depended on for modular

Draft model page: https://huggingface.co/bezzam/xcodec2

main_input_name = "input_features"
input_modalities = "audio"
supports_gradient_checkpointing = True
_no_split_modules = ["Wav2Vec2BertEncoderLayer"]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To allow loading with device_map="auto"

@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
super()._init_weights(module)
Copy link
Copy Markdown
Contributor Author

@ebezzam ebezzam Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XCodec2 uses a pretrained checkpoint of Wav2Vec2-BERT, but Xcodec2's test test_can_init_all_missing_weights was failing because Embedding wasn't initialized. We can rely on the base _init_weights and also remove some initialization from below

Comment on lines +166 to +168
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=config.hidden_size, eps=1e-6, affine=True)
self.activation1 = nn.SiLU()
self.conv1 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=1, padding=1)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to PeAudioVideoConvBlock1d but slight differences that don't make modular direct here?

Comment on lines +134 to +139
class SnakeBeta(SnakeBeta):
pass


class AntiAliasedActivation1d(AntiAliasedActivation1d):
pass
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought just importing above would have been enough, but it wasn't generating the classes without this 🤔

Comment on lines +258 to +268
# Back to audio (ISTFT with "same" padding)
time_frames = torch.fft.irfft(spectrogram_complex, self.n_fft, dim=1, norm="backward")
time_frames = time_frames * self.window[None, :, None]
num_frames = spectrogram_complex.shape[-1]
output_size = (num_frames - 1) * self.hop_length + self.win_length
audio = F.fold(
time_frames,
output_size=(1, output_size),
kernel_size=(1, self.win_length),
stride=(1, self.hop_length),
)[:, 0, 0, self.padding : -self.padding]
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.istft doesn't support the custom padding needed here for integrations tests to match expected output

Comment on lines +296 to +299
hidden_states = self.finite_scalar_quantization.bound(
hidden_states
) # For consistency with original checkpoint
quantized_out, indices = self.finite_scalar_quantization(hidden_states)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling self.finite_scalar_quantization.bound is a bit redundant, as it's called within self.finite_scalar_quantization(hidden_states). But the original modeling did it and it is needed to match expected outputs.

return hidden_states + residual


class Xcodec2FiniteScalarQuantization(nn.Module):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new component

return codes, indices


class Xcodec2ISTFTHead(nn.Module):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to what we saw in the Vocos PR

@ebezzam ebezzam requested a review from eustlb March 19, 2026 12:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants