# Controlled Accent Conversion - Inference Demo

This notebook demonstrates inference with the Controlled Accent Conversion model.

**Paper**: [Controlled Accent Conversion](https://arxiv.org/abs/2510.10785) (ICASSP 2026)

**Demo**: [Listen to samples](https://claussss.github.io/accent_control_demo/)

## Prerequisites

1. Download checkpoints and stats from [Google Drive](https://drive.google.com/drive/folders/1Pnq_XV5VA_hcIpoOYfbSnLZFA3GKGk1C?usp=sharing)
2. Place the `stats/` folder and the checkpoint in your project directory (the default configuration below expects `./stats` and `./checkpoints/`)
3. Install dependencies: `pip install -r requirements.txt`

## Configuration

Set your paths below:

In [None]:
# ============ CONFIGURE PATHS HERE ============
# Every later cell reads these constants; edit them before running the notebook.

# Path to the diffusion model checkpoint (a state dict loaded with torch.load)
CHECKPOINT_PATH = './checkpoints/model_exp_lin_sched_100_step_snr_17.pt'

# Path to stats directory; it must contain mean_zc1_indx.pt, std_zc1_indx.pt,
# mean_zc2_indx.pt and std_zc2_indx.pt
STATS_PATH = './stats'

# Path to audio file for conversion (downmixed to mono and resampled to
# 16 kHz when it is loaded)
AUDIO_PATH = './sample_audio.wav'

# Transcript of the audio (required for phoneme alignment)
TRANSCRIPT = "To my dearest and always appreciated friend I submit myself"

# Torch device string that models and tensors are moved to
DEVICE = 'cuda'  # or 'cpu'

# ==============================================

## 1. Import Dependencies

In [None]:
import sys
import os
import math
import random
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from einops import rearrange
import re

# Phonemizer imports
from phonemizer.separator import Separator
from phonemizer.backend import EspeakBackend

# ASR model for forced alignment
from transformers import (
    Wav2Vec2FeatureExtractor,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
)
from nemo_text_processing.text_normalization.normalize import Normalizer

# FACodec
sys.path.append('../Amphion')
from Amphion.models.codec.ns3_codec import FACodecEncoder, FACodecDecoder
from huggingface_hub import hf_hub_download

# Our modules
sys.path.append('.')
from FACodec_AC.models import DenoisingTransformerModel
from FACodec_AC.config import Config
from FACodec_AC.utils import (
    get_phone_forced_alignment, 
    interpolate_alignment, 
    QuantizerNames, 
    get_z_from_indx, 
    snap_latent,
    standardize,
    destandardize
)
from IPython.display import Audio

print(f"Using device: {DEVICE}")

## 2. Initialize FACodec Encoder/Decoder

FACodec is used to encode audio into discrete representations and decode them back.

In [None]:
# Build the FACodec encoder/decoder pair.
# NOTE(review): these hyperparameters presumably have to match the pretrained
# `amphion/naturalspeech3_facodec` weights loaded below - confirm before
# changing any of them.
fa_encoder = FACodecEncoder(
    ngf=32,
    up_ratios=[2, 4, 5, 5],
    out_channels=256
)
fa_decoder = FACodecDecoder(
    in_channels=256,
    upsample_initial_channel=1024,
    ngf=32,
    up_ratios=[5, 5, 4, 2],
    vq_num_q_c=2,
    vq_num_q_p=1,
    vq_num_q_r=3,
    vq_dim=256,
    codebook_dim=8,
    codebook_size_prosody=10,
    codebook_size_content=10,
    codebook_size_residual=10,
    use_gr_x_timbre=True,
    use_gr_residual_f0=True,
    use_gr_residual_phone=True,
)

# Fetch the pretrained weights from the Hugging Face Hub
encoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_encoder.bin")
decoder_ckpt = hf_hub_download(repo_id="amphion/naturalspeech3_facodec", filename="ns3_facodec_decoder.bin")

# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoints were saved from; the modules are moved to DEVICE right after.
fa_encoder.load_state_dict(torch.load(encoder_ckpt, map_location="cpu"))
fa_decoder.load_state_dict(torch.load(decoder_ckpt, map_location="cpu"))

# Inference only: switch both modules to eval mode
fa_encoder.eval()
fa_decoder.eval()
fa_encoder = fa_encoder.to(DEVICE)
fa_decoder = fa_decoder.to(DEVICE)

print("FACodec loaded successfully!")

## 3. Initialize Phoneme Alignment Pipeline

We use Wav2Vec2 for forced phoneme alignment.

In [None]:
# Wav2Vec2 checkpoint used for forced alignment
_ASR_CHECKPOINT = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"

# Feature extractor + tokenizer are loaded separately and wrapped in a processor
fe = Wav2Vec2FeatureExtractor.from_pretrained(_ASR_CHECKPOINT)
tok = Wav2Vec2CTCTokenizer.from_pretrained(_ASR_CHECKPOINT, use_fast=False)
proc = Wav2Vec2Processor(feature_extractor=fe, tokenizer=tok)

# CTC model on the target device, in eval mode for inference
asr_model = Wav2Vec2ForCTC.from_pretrained(_ASR_CHECKPOINT).to(DEVICE)
asr_model.eval()

# Sampling rate the feature extractor expects
target_sr = fe.sampling_rate  # 16000

# Modifier symbols that the 'DROP_RE' pattern below matches
REMOVE = ''.join(('ʲ', 'ʷ', 'ʰ', 'ʱ', 'ˠ', 'ˤ', 'ʶ', 'ʵ'))

# Everything the alignment helper needs: text normalisation, phonemisation
# and the Wav2Vec2 processor/model pair
pipeline = {
    "normaliser": Normalizer(
        lang="en",
        input_case="cased",
        deterministic=True,
        post_process=True,
    ),
    "regex": re.compile(r"[^a-z' ]"),
    "sep": Separator(phone=" ", word="|", syllable=""),
    "backend": EspeakBackend('en-us'),
    'wav2vec_processor': proc,
    'wav2vec_model': asr_model,
    'DROP_RE': re.compile(f'[{re.escape(REMOVE)}]'),
}

print("Phoneme alignment pipeline ready!")

## 4. Load Diffusion Model and Stats

In [None]:
# Build the denoising transformer from the project Config
diffusion_model = DenoisingTransformerModel(
    d_model=Config.d_model,
    nhead=Config.nhead,
    num_layers=Config.num_layers,
    d_ff=Config.d_ff,
    dropout=Config.dropout,
    max_seq_len=Config.max_seq_len,
    FACodec_dim=Config.FACodec_dim,
    phone_vocab_size=Config.PHONE_VOCAB_SIZE,
    num_steps=100,  # Number of diffusion steps
)

# Load checkpoint (remapped to DEVICE so it also loads on CPU-only machines)
diffusion_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
diffusion_model = diffusion_model.to(DEVICE)
diffusion_model.eval()

# Load the normalization stats the same way as the checkpoint: map_location
# keeps loading independent of the device the tensors were saved from.
zc1_mean = torch.load(os.path.join(STATS_PATH, 'mean_zc1_indx.pt'), map_location=DEVICE).to(DEVICE)
zc1_std = torch.load(os.path.join(STATS_PATH, 'std_zc1_indx.pt'), map_location=DEVICE).to(DEVICE)
zc2_mean = torch.load(os.path.join(STATS_PATH, 'mean_zc2_indx.pt'), map_location=DEVICE).to(DEVICE)
zc2_std = torch.load(os.path.join(STATS_PATH, 'std_zc2_indx.pt'), map_location=DEVICE).to(DEVICE)

print(f"Diffusion model loaded from: {CHECKPOINT_PATH}")
print(f"Stats loaded from: {STATS_PATH}")

## 5. Load and Process Audio

In [None]:
# Read the waveform and its native sampling rate from disk
wav_waveform, wav_sr = torchaudio.load(AUDIO_PATH)

# Multi-channel input: average the channels down to a single one
if wav_waveform.size(0) > 1:
    wav_waveform = torch.mean(wav_waveform, dim=0, keepdim=True)

# The rest of the notebook treats the waveform as 16 kHz audio
if wav_sr != 16000:
    resample = torchaudio.transforms.Resample(orig_freq=wav_sr, new_freq=16000)
    wav_waveform = resample(wav_waveform)

wav_waveform = wav_waveform.to(DEVICE)
print(f"Audio shape: {wav_waveform.shape}")

# Last expression of the cell: renders a player for the original audio
Audio(wav_waveform.cpu().numpy(), rate=16000)

## 6. Extract FACodec Representations

In [None]:
# Run the codec without gradient tracking: a batch axis is added in front of
# the waveform, then the decoder's quantizers are applied to the encoding
with torch.no_grad():
    h_input = fa_encoder(wav_waveform.unsqueeze(0))
    vq_post_emb, vq_id, _, quantized_arr, spk_embs = fa_decoder(h_input, eval_vq=False, vq=True)

# Entry 1 of the index stack holds the content indices used below
zc1_indx = vq_id[1]
seq_len = zc1_indx.size(1)
print(f"Sequence length: {seq_len}")

## 7. Get Phoneme Forced Alignment

In [None]:
# The alignment helper looks the transcript up by the audio file's stem
audio_folder = os.path.dirname(AUDIO_PATH)
file_id = os.path.splitext(os.path.basename(AUDIO_PATH))[0]
transcript_dict = {file_id: TRANSCRIPT}

# Forced alignment; only the phone ids and the frame scores are kept
predicted_ids, _, frames_score, _ = get_phone_forced_alignment(
    embedding_path=f"{file_id}.pt",
    audio_folder=audio_folder,
    transcript_metadata=transcript_dict,
    device=DEVICE,
    target_sr=target_sr,
    pipeline=pipeline,
    inference=True,
)

# Stretch the alignment to the FACodec sequence length and take the first item
interpolated_phone_ids = interpolate_alignment(predicted_ids, seq_len)[0]
print(f"Phone IDs shape: {interpolated_phone_ids.shape}")

## 8. DDIM Sampling

We use DDIM (Denoising Diffusion Implicit Models) for faster sampling.

In [None]:
def make_ddim_schedule(num_train_steps: int, num_infer_steps: int) -> torch.LongTensor:
    """Return the timesteps visited during DDIM inference.

    ``num_infer_steps`` evenly spaced points are taken from
    ``num_train_steps - 1`` down to ``0`` and rounded to integers. The result
    is a 1-D long tensor placed on ``DEVICE``.
    """
    last_train_step = num_train_steps - 1
    spaced = np.linspace(last_train_step, 0, num_infer_steps)
    rounded = np.round(spaced)
    return torch.tensor(rounded, dtype=torch.long, device=DEVICE)


def sample_ddim(model, zc1_clean, padded_phone_ids, padding_mask, schedule, start_idx=0, seed=42):
    """Run deterministic DDIM sampling and return both predicted latents.

    ``zc1_clean`` is first noised to the timestep ``schedule[start_idx]`` and
    then denoised along the remaining entries of ``schedule``. A final forward
    pass at ``t = 0`` on the denoised latent yields the model's second output.
    The whole procedure runs without gradient tracking.

    Args:
        model: Denoiser exposing ``sqrt_abar`` and ``sqrt_1mabar`` tables
            indexable by timestep, callable as ``model(zc1_noisy=...,
            padded_phone_ids=..., t=..., padding_mask=...)`` and returning a
            pair ``(eps_pred, zc2_pred)``.
        zc1_clean: 3-D tensor whose first dimension is the batch.
        padded_phone_ids: Passed through to ``model`` unchanged.
        padding_mask: Passed through to ``model`` unchanged.
        schedule: Sequence of integer timesteps (tensor or list) in the order
            they are visited.
        start_idx: Position in ``schedule`` at which sampling starts. With a
            descending schedule, a larger value starts from a less noisy
            timestep.
        seed: Seed applied to the global ``torch``, ``numpy`` and ``random``
            generators before any noise is drawn.

    Returns:
        Tuple ``(zc1_pred, zc2_pred)``; neither tensor tracks gradients.

    Raises:
        IndexError: If ``start_idx`` is not a valid index into ``schedule``.
    """
    # Reseed every global RNG so repeated calls draw the same noise
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    model.eval()
    bsz, _, _ = zc1_clean.shape
    # Follow the input's device rather than a notebook-level global
    device = zc1_clean.device
    # Plain ints avoid passing 0-dim tensors around as fill values / indices
    steps = [int(t) for t in schedule]

    sqrt_abar = model.sqrt_abar
    sqrt_1mabar = model.sqrt_1mabar

    def _coeffs(t):
        """Signal and noise scales of timestep ``t``, shaped to broadcast."""
        return sqrt_abar[t].view(1, 1, 1), sqrt_1mabar[t].view(1, 1, 1)

    with torch.no_grad():
        # Forward diffuse to the starting step
        sa_start, s1a_start = _coeffs(steps[start_idx])
        eps0 = torch.randn_like(zc1_clean)
        z_t = sa_start * zc1_clean + s1a_start * eps0

        # Reverse DDIM sampling
        last_idx = len(steps) - 1
        for i in range(start_idx, len(steps)):
            t = steps[i]
            t_idx = torch.full((bsz,), t, device=device, dtype=torch.long)

            eps_pred, _ = model(
                zc1_noisy=z_t,
                padded_phone_ids=padded_phone_ids,
                t=t_idx,
                padding_mask=padding_mask
            )

            # Estimate the clean latent from the predicted noise
            sa, s1a = _coeffs(t)
            x0_hat = (z_t - s1a * eps_pred) / sa

            if i < last_idx:
                # Re-noise the estimate to the next timestep with the same eps
                sa_prev, s1a_prev = _coeffs(steps[i + 1])
                z_t = sa_prev * x0_hat + s1a_prev * eps_pred
            else:
                z_t = x0_hat

        zc1_pred = z_t

        # Predict zc2 from the denoised zc1; kept inside no_grad so that no
        # autograd graph is built during inference
        t_zero = torch.zeros((bsz,), device=device, dtype=torch.long)
        _, zc2_pred = model(
            zc1_noisy=zc1_pred,
            padded_phone_ids=padded_phone_ids,
            t=t_zero,
            padding_mask=padding_mask
        )

    return zc1_pred, zc2_pred

## 9. Run Accent Conversion

In [None]:
# Look up the latent for the content indices and standardize it with the
# zc1 statistics
x_input = get_z_from_indx(
    vq_id[1],
    fa_decoder=fa_decoder,
    layer=0,
    quantizer_num=QuantizerNames.content,
    dim=Config.FACodec_dim,
)
zc1_input_normalized = standardize(x_input, zc1_mean, zc1_std)

# DDIM schedule: T training steps sampled at K inference steps
T = 100  # Training steps
K = 100  # Inference steps
schedule = make_ddim_schedule(num_train_steps=T, num_infer_steps=K)

# noise_level is the schedule index where sampling starts
noise_level = 0  # 0 = full diffusion from noise, higher = less noise
seed = 42

# Model inputs; the all-False padding mask has one entry per (batch, frame)
bsz, C, seq_len = zc1_input_normalized.shape
padded_phone_ids = interpolated_phone_ids
padding_mask = torch.zeros((bsz, seq_len), dtype=torch.bool, device=DEVICE)

# Sample both content latents
zc1_pred, zc2_pred = sample_ddim(
    model=diffusion_model,
    zc1_clean=zc1_input_normalized,
    padded_phone_ids=padded_phone_ids,
    padding_mask=padding_mask,
    schedule=schedule,
    start_idx=noise_level,
    seed=seed,
)

# Undo the standardization, each latent with its own statistics
zc1_pred = destandardize(zc1_pred, zc1_mean, zc1_std)
zc2_pred = destandardize(zc2_pred, zc2_mean, zc2_std)

print("Diffusion sampling complete!")

## 10. Reconstruct Audio

In [None]:
# Inference only: keep the projections inside no_grad as well, so that no
# autograd graph is built through the decoder's out_proj layers.
with torch.no_grad():
    # Snap to codebook and project to 256-dim
    zc1_pred_snapped = snap_latent(zc1_pred.transpose(1, 2), fa_decoder, layer=0)
    zc1_pred_256 = fa_decoder.quantizer[1].layers[0].out_proj(zc1_pred_snapped).transpose(1, 2)
    zc2_pred_256 = fa_decoder.quantizer[1].layers[1].out_proj(zc2_pred.transpose(1, 2)).transpose(1, 2)

    # Combine content codes
    x_reconstructed = zc1_pred_256 + zc2_pred_256

    # Decode with FACodec
    # Combine: prosody + predicted content + acoustic residuals
    final_code = quantized_arr[0] + x_reconstructed + quantized_arr[2]
    wav_output = fa_decoder.inference(final_code, spk_embs)

print("Audio reconstruction complete!")

## 11. Listen to Results

In [None]:
# Play back the input waveform at 16 kHz for comparison; the Audio object is
# the cell's last expression so the notebook renders it as a player.
print("Original Audio:")
Audio(wav_waveform.cpu().numpy(), rate=16000)

In [None]:
# Play back the decoded output at 16 kHz; squeeze() drops the size-1
# dimensions before the tensor is handed to the player.
print("Converted Audio:")
Audio(wav_output.squeeze().cpu().numpy(), rate=16000)

## 12. Save Output (Optional)

In [None]:
# Save converted audio next to the input as "<input stem>_converted.wav".
# The name is built from the path's stem instead of str.replace('.wav', ...):
# replace() leaves a path without ".wav" unchanged, so the save would
# overwrite the source file, and it also rewrites ".wav" inside directory names.
output_path = os.path.splitext(AUDIO_PATH)[0] + '_converted.wav'
# squeeze(0) drops the leading dimension before the tensor is written at 16 kHz
torchaudio.save(output_path, wav_output.squeeze(0).cpu(), 16000)
print(f"Saved to: {output_path}")