In [None]:
import os
import time
from pathlib import Path

import numpy as np
import torch
import torchaudio
import ipywidgets as widgets
from IPython.display import display, Audio, clear_output
from onnxruntime import InferenceSession

import padertorch as pt
import paderbox as pb
from paderbox.transform.module_resample import resample_sox
from pvq_manipulation.models.vits import Vits_NT
from pvq_manipulation.models.ffjord import FFJORD
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER
from pvq_manipulation.helper.vad import EnergyVAD
from train_tts_nt.helper.utils import rms_norm

# Load the TTS model

In [None]:
# TTS model (`Vits_NT`), restored from a fixed training checkpoint.
tts_checkpoint = "checkpoint_390000.pth"
storage_dir_tts = Path("./Saved_models/tts_model/")
tts_model = Vits_NT.load_model(storage_dir_tts, checkpoint=tts_checkpoint)

# Load the normalizing flow

In [None]:
storage_dir_normalizing_flow = Path("./Saved_models/norm_flow")
norm_flow_checkpoint = "checkpoints/ckpt_best_loss.pth"

# The config is kept around: its `speaker_conditioning` entry is later used to
# order the label vector that conditions the flow.
config_norm_flow = pb.io.load_yaml(storage_dir_normalizing_flow / "config.yaml")
normalizing_flow = FFJORD.load_model(
    storage_dir_normalizing_flow,
    checkpoint=norm_flow_checkpoint,
)

# Load the HuBERT feature extractor

In [None]:
# Directory holding the HuBERT checkpoint. The default is a machine-specific
# absolute path; set the HUBERT_STORAGE_DIR environment variable to use another
# location without editing the notebook.
hubert_storage_dir = os.environ.get("HUBERT_STORAGE_DIR", '/net/vol/rautenberg/storage/hubert')

hubert_model = HubertExtractor(
    layer=SID_LARGE_LAYER,
    model_name="HUBERT_LARGE",
    backend="torchaudio",
    device='cpu',
    storage_dir=hubert_storage_dir,  # target storage dir of the HuBERT model
)

# Example synthesis

In [None]:
speaker_id = 1034
example_id = "1034_121119_000028_000001"

# Synthesize using a precomputed speaker embedding stored on disk.
d_vector_path = f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth"
synthesis_request = {
    'text': "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    'd_vector_storage_root': d_vector_path,
}

wav_1 = tts_model.synthesize_from_example(synthesis_request)
display(Audio(wav_1, rate=24_000, normalize=True))

# Manipulation helpers

In [None]:
def get_manipulation(
    example,
    d_vector,
    labels,
    flow,
    tts_model,
    manipulation_idx=0,
    manipulation_fkt=1,
    text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
):
    """Synthesize ``text`` with one of the speaker's conditioning labels shifted.

    The speaker embedding is pushed through ``flow.forward`` with the original
    labels and mapped back with ``flow.sample`` using the manipulated labels.
    The TTS model receives both the original and the manipulated embedding.

    Args:
        example: unused; kept so existing callers passing it keep working.
        d_vector: speaker embedding tensor.
        labels: conditioning labels, indexed as ``labels[:, idx]`` (batch axis first).
        flow: object exposing ``forward((x, labels))`` and ``sample((z, labels))``,
            each returning a sequence whose first element is the result.
        tts_model: object exposing ``synthesize_from_example(dict)``.
        manipulation_idx: column of ``labels`` to shift.
        manipulation_fkt: amount added to that column.
        text: sentence to synthesize.

    Returns:
        The result of ``tts_model.synthesize_from_example``.
    """
    # Work on a copy so the caller's label tensor is not modified.
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt

    # NOTE(review): intentionally not wrapped in torch.no_grad(); continuous flows
    # such as FFJORD may use autograd internally — confirm before adding it.
    output_forward = flow.forward((d_vector.float(), labels))[0]
    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]

    # .cpu() makes the numpy conversion work for tensors on any device.
    wav = tts_model.synthesize_from_example({
        'text': text,
        'd_vector': d_vector.detach().cpu().numpy(),
        'd_vector_man': sampled_class_manipulated.detach().cpu().numpy(),
    })
    return wav

def extract_speaker_embedding(example):
    """Compute the speaker embedding of the example's audio with the TTS speaker manager.

    The audio file is resampled to 16 kHz, given a leading axis if it is 1-D,
    passed through ``EnergyVAD``, and then fed to ``tts_model.speaker_manager``
    as a collated batch of one example.

    Args:
        example: dict with ``example['audio_path']['observation']`` (path to an
            audio file); all of its keys are also forwarded to
            ``speaker_manager.prepare_example``.

    Returns:
        The first element of ``speaker_manager.forward`` for that batch.
    """
    target_rate = 16_000
    audio_path = example['audio_path']['observation']

    signal, source_rate = pb.io.load_audio(audio_path, return_sample_rate=True)
    signal = resample_sox(signal, in_rate=source_rate, out_rate=target_rate)

    voice_detector = EnergyVAD(sample_rate=target_rate)
    # Add a leading axis (presumably channels) to mono signals.
    signal = signal[None, :] if signal.ndim == 1 else signal
    voiced = voice_detector({'audio_data': signal})['audio_data']

    with torch.no_grad():
        manager = tts_model.speaker_manager
        prepared = manager.prepare_example({'audio_data': {'observation': voiced}, **example})
        batch = pt.data.utils.collate_fn([prepared])
        batch['features'] = torch.tensor(np.array(batch['features']))
        embedding = manager.forward(batch)[0]
    return embedding

In [None]:
def load_speaker_labels(example, config_norm_flow, reg_stor_dir=Path('./Saved_models/pvq_extractor/')):
    """Predict the flow's conditioning labels from the example's audio.

    HuBERT features of the audio file are averaged over their last axis and
    fed to one ONNX regressor per attribute (``<reg_stor_dir>/<name>.onnx``).
    Each prediction is divided by 100. The result is ordered like
    ``config_norm_flow['speaker_conditioning']``.

    Args:
        example: dict with ``example['audio_path']['observation']`` (path to an audio file).
        config_norm_flow: normalizing-flow config; ``'speaker_conditioning'`` lists
            the attribute names in label order.
        reg_stor_dir: directory (``Path`` or ``str``) containing the ONNX regressors.

    Returns:
        1-D float tensor with one entry per conditioning attribute.
    """
    # Pass the file's real sample rate to HuBERT instead of assuming 24 kHz.
    audio, sample_rate = torchaudio.load(example['audio_path']['observation'])
    num_samples = torch.tensor([audio.shape[-1]])

    # Inputs stay on the CPU: `hubert_model` is constructed with device='cpu',
    # so moving them to CUDA would mismatch the model's device.
    with torch.no_grad():
        features, _ = hubert_model(
            audio,
            sample_rate,
            sequence_lengths=num_samples,
        )
    features = np.mean(features.squeeze(0).detach().cpu().numpy(), axis=-1)

    reg_stor_dir = Path(reg_stor_dir)
    providers = ["CPUExecutionProvider"]
    labels = []
    # Only the attributes the flow is conditioned on are predicted, in config order,
    # which is the column order the flow expects.
    for pvq in config_norm_flow['speaker_conditioning']:
        sess = InferenceSession(str(reg_stor_dir / f"{pvq}.onnx"), providers=providers)
        pred = sess.run(None, {"X": features[None]})[0].squeeze(1)
        labels.append(float(pred[0]) / 100)
    return torch.tensor(labels)

# Example manipulation

In [None]:
example = {
    'audio_path': {'observation': "./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav"},
    'speaker_id': 1034,
    'example_id': "1034_121119_000028_000001",
}

d_vector = extract_speaker_embedding(example)
labels = load_speaker_labels(example, config_norm_flow)

manipulation_idx = 0
manipulation_fkt = 1

wav_manipulated = get_manipulation(
    example=example,
    d_vector=d_vector,
    labels=labels[None, :],  # add a batch axis; the label column is selected as [:, idx]
    flow=normalizing_flow,
    tts_model=tts_model,
    manipulation_idx=manipulation_idx,
    manipulation_fkt=manipulation_fkt,
)

# The label vector follows the config order, so this is the attribute that was shifted.
manipulated_attribute = list(config_norm_flow['speaker_conditioning'])[manipulation_idx]
print(f"Manipulated {manipulated_attribute} with strength {manipulation_fkt}")
display(Audio(wav_manipulated, rate=24_000, normalize=True))

In [None]:
example = {
    'audio_path': {'observation': "./Saved_models/Dataset/Audio_files/1034_121119_000028_000001.wav"},
    'speaker_id': 1034,
    'example_id': "1034_121119_000028_000001",
}

# The dropdown value is used as the column index into the label vector, and
# `load_speaker_labels` builds that vector in the order of
# `config_norm_flow['speaker_conditioning']`. Deriving the options from the
# config (instead of a hand-written list) keeps each name tied to its column.
label_options = list(config_norm_flow['speaker_conditioning'])

manipulation_idx_widget = widgets.Dropdown(
    options=[(label, i) for i, label in enumerate(label_options)],
    # Default: Breathiness (first attribute if the flow is not conditioned on it).
    value=label_options.index('Breathiness') if 'Breathiness' in label_options else 0,
    description='Type:',
    style={'description_width': 'initial'}
)

manipulation_fkt_widget = widgets.FloatSlider(
    value=1.0, min=-2.0, max=2.0, step=0.1,
    description='Strength:',
    style={'description_width': 'initial'}
)

run_button = widgets.Button(description="Run Manipulation")

audio_output = widgets.Output()

# The example is fixed, so its embedding and labels are computed once here rather
# than on every click (both run full models over the audio file).
widget_d_vector = extract_speaker_embedding(example)
widget_labels = load_speaker_labels(example, config_norm_flow)


def update_manipulation(b):
    """Button callback: synthesize with the selected label shifted and play it next to the original.

    Args:
        b: the clicked button (unused; supplied by ``Button.on_click``).
    """
    manipulation_idx = manipulation_idx_widget.value
    manipulation_fkt = manipulation_fkt_widget.value

    # All work runs inside the Output widget so the status text, the audio and any
    # traceback appear below the controls; output or errors produced by a widget
    # callback outside an Output context are not shown in the cell.
    with audio_output:
        clear_output(wait=True)
        display(widgets.Label("Processing..."))

        wav_manipulated = get_manipulation(
            example=example,
            d_vector=widget_d_vector,
            labels=widget_labels[None, :],
            flow=normalizing_flow,
            tts_model=tts_model,
            manipulation_idx=manipulation_idx,
            manipulation_fkt=manipulation_fkt,
        )

        clear_output(wait=True)
        display(widgets.Label(f"Manipulated {label_options[manipulation_idx]} with strength {manipulation_fkt}"))
        display(Audio(wav_manipulated, rate=24_000, normalize=True))
        display(widgets.Label("Original recording:"))
        display(Audio(example['audio_path']['observation'], rate=24_000, normalize=True))


run_button.on_click(update_manipulation)
display(manipulation_idx_widget, manipulation_fkt_widget, run_button, audio_output)