In [None]:
import numpy as np 
from pathlib import Path
import padertorch as pt
import paderbox as pb
import torch
import torchaudio

import ipywidgets as widgets
from onnxruntime import InferenceSession
from pvq_manipulation.models.vits import Vits_NT
from pvq_manipulation.models.ffjord import FFJORD
from IPython.display import display, Audio, clear_output
from paderbox.transform.module_resample import resample_sox
from pvq_manipulation.helper.vad import EnergyVAD
from pvq_manipulation.models.hubert import HubertExtractor, SID_LARGE_LAYER
from pvq_manipulation.helper.creapy_wrapper import process_file

# Load TTS model

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Restore the TTS checkpoint and move the model onto the selected device.
storage_dir_tts = Path("./Saved_models/tts_model/")
tts_model = Vits_NT.load_model(
    storage_dir_tts,
    checkpoint="checkpoint.pth",
)
tts_model.to(device)

# Load normalizing flow

In [None]:
# One directory holds both the flow's YAML config and its checkpoint.
storage_dir_normalizing_flow = Path("./Saved_models/norm_flow_creak/")

# The config is read again later: its flags select how d-vectors are normalized.
config_norm_flow = pb.io.load_yaml(
    storage_dir_normalizing_flow / "config.yaml"
)
normalizing_flow = FFJORD.load_model(
    storage_dir_normalizing_flow,
    checkpoint="ckpt_best_loss.pth",
)

# Load HuBERT model

In [None]:
# TODO: point this at the directory where the HuBERT model should be stored.
# The original cell left `storage_dir=` without a value, which is a SyntaxError
# and stops "Restart & Run All". The path below is only a placeholder that
# follows the "./Saved_models/..." layout used elsewhere in this notebook;
# whether HubertExtractor expects a Path or a str should be confirmed.
hubert_storage_dir = Path("./Saved_models/hubert/")

hubert_model = HubertExtractor(
    layer=SID_LARGE_LAYER,
    model_name="HUBERT_LARGE",
    backend="torchaudio",
    device=device,
    storage_dir=hubert_storage_dir,
)

# Example Synthesis

In [None]:
# Reference utterance whose stored speaker embedding drives the synthesis.
speaker_id = 1034
example_id = "1034_121119_000028_000001"

# Synthesize the sentence from the embedding file of that utterance and play it.
wav_1 = tts_model.synthesize_from_example(
    {
        'text': "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
        'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth",
    }
)
display(Audio(wav_1, rate=24_000, normalize=True))

# Manipulation Block

In [None]:
def get_manipulation(
    example,
    labels,
    flow,
    tts_model,
    d_vector,
    manipulation_idx=0,
    manipulation_fkt=1,
):
    """Synthesize ``example['transcription']`` with a manipulated speaker embedding.

    One column of ``labels`` is shifted, the (optionally normalized) d-vector
    is passed through ``flow.forward`` with the original labels, and
    ``flow.sample`` is then called with the shifted labels. The result is
    handed to the TTS model as ``'d_vector_man'``.

    Normalization is selected by the module-level ``config_norm_flow``:
    ``flag_normalize_d_vectors`` delegates to ``tts_model.normalize_d_vectors``;
    otherwise ``flag_remove_mean`` standardizes with ``mean.json`` / ``std.json``
    found two directory levels above ``example['d_vector_storage_root']``.

    Args:
        example: dict providing ``'transcription'`` and ``'d_vector_storage_root'``.
        labels: 2-D tensor of labels (rows x label columns); it is not modified.
        flow: object exposing ``forward`` and ``sample``; each takes a tuple
            and the first element of its result is used.
        tts_model: object exposing ``normalize_d_vectors`` and
            ``synthesize_from_example``.
        d_vector: speaker embedding tensor.
        manipulation_idx: column of ``labels`` to shift.
        manipulation_fkt: amount added to that column.

    Returns:
        Whatever ``tts_model.synthesize_from_example`` returns.
    """
    labels_manipulated = labels.clone()
    labels_manipulated[:, manipulation_idx] += manipulation_fkt

    # Directory used both by normalize_d_vectors and for the mean/std files.
    stats_dir = Path(example['d_vector_storage_root']).parent.parent

    # Stay None unless the mean/std standardization below is actually applied,
    # so the inverse transform further down never touches unbound names.
    global_mean = None
    global_std = None

    if config_norm_flow['flag_normalize_d_vectors']:
        speaker_embedding_norm = tts_model.normalize_d_vectors(d_vector, stats_dir)
    elif config_norm_flow['flag_remove_mean']:
        global_mean = torch.tensor(
            pb.io.load(stats_dir / "mean.json"),
            dtype=torch.float32,
            device=d_vector.device,
        )
        # The original read std.json through an undefined name `file_path`;
        # it is loaded from the same directory as mean.json instead.
        global_std = torch.tensor(
            pb.io.load(stats_dir / "std.json"),
            dtype=torch.float32,
            device=d_vector.device,
        )
        speaker_embedding_norm = (d_vector - global_mean) / global_std
    else:
        speaker_embedding_norm = d_vector

    output_forward = flow.forward((speaker_embedding_norm.float(), labels))[0]
    sampled_class_manipulated = flow.sample((output_forward, labels_manipulated))[0]

    # Undo the standardization only if it was applied above.
    # NOTE(review): when both flags are set the original raised NameError here;
    # skipping the inverse transform in that case is an assumption - confirm.
    if global_mean is not None:
        sampled_class_manipulated = (
            sampled_class_manipulated * global_std.to(sampled_class_manipulated.device)
            + global_mean.to(sampled_class_manipulated.device)
        )

    # .cpu() is required before .numpy() when the tensors live on a GPU.
    wav = tts_model.synthesize_from_example({
        'text': example['transcription'],
        'd_vector': d_vector.detach().cpu().numpy(),
        'd_vector_man': sampled_class_manipulated.detach().cpu().numpy(),
        'd_vector_storage_root': example['d_vector_storage_root'],
    })
    return wav

def extract_speaker_embedding(example):
    """Compute a d-vector for ``example`` with the TTS model's speaker manager.

    Reads the audio stored under ``example['loaded_audio_data']['16_000']``,
    runs it through ``speaker_manager.prepare_example``, collates the result
    into a batch of one, and returns the first element of
    ``speaker_manager.forward``. No gradients are tracked.
    """
    waveform_16k = example['loaded_audio_data']['16_000']
    with torch.no_grad():
        # Keys already present in `example` take precedence over 'audio_data'.
        prepared = tts_model.speaker_manager.prepare_example(
            {'audio_data': {'observation': waveform_16k}, **example}
        )
        batch = pt.data.utils.collate_fn([prepared])
        batch['features'] = torch.tensor(np.array(batch['features']))
        embedding = tts_model.speaker_manager.forward(batch)[0]
    return embedding

# Names of the voice-quality dimensions; each name is also the stem of an
# "<name>.onnx" file read by load_speaker_labels.
pvq_labels = [
    'Weight',
    'Resonance',
    'Breathiness',
    'Roughness',
    'Loudness',
    'Strain',
    'Pitch',
]
def get_creak_label(example):
    """Return 100 times the mean creak prediction for the example's 16 kHz audio.

    ``process_file`` returns three values; the second is indexed with the
    third, and the mean of that selection is scaled by 100.
    """
    # The first return value of process_file is not needed here.
    _, creak_pred, kept_indices = process_file(example['loaded_audio_data']['16_000'])
    return np.mean(creak_pred[kept_indices]) * 100


def load_speaker_labels(example, config_norm_flow, reg_stor_dir=Path('./Saved_models/pvq_extractor/')):
    """Build the label vector for ``example``: one value per ``pvq_labels`` entry plus creak.

    The 16 kHz audio is passed through ``hubert_model``, the features are
    averaged over their last axis, and one ONNX regressor per name in
    ``pvq_labels`` (``<name>.onnx`` inside ``reg_stor_dir``) is run on them.
    The value from ``get_creak_label`` is appended as ``'Creak_mean'``.
    Every value is divided by 100.

    Args:
        example: dict with the audio under ``['loaded_audio_data']['16_000']``.
        config_norm_flow: accepted for the callers' sake; not read here.
        reg_stor_dir: directory that contains the ``.onnx`` regressors.

    Returns:
        1-D float tensor on ``device``, ordered like ``pvq_labels`` followed
        by the creak value.
    """
    waveform = torch.tensor(example['loaded_audio_data']['16_000'], dtype=torch.float)[None, :]
    lengths = torch.tensor([waveform.shape[-1]])
    if torch.cuda.is_available():
        waveform, lengths = waveform.cuda(), lengths.cuda()

    # The ONNX regressors always run on the CPU provider.
    onnx_providers = ["CPUExecutionProvider"]
    predictions = {}

    with torch.no_grad():
        hubert_features, _ = hubert_model(
            waveform,
            16_000,
            sequence_lengths=lengths,
        )
        pooled = np.mean(hubert_features.squeeze(0).detach().cpu().numpy(), axis=-1)

        for name in pvq_labels:
            with open(reg_stor_dir / f"{name}.onnx", "rb") as model_file:
                model_bytes = model_file.read()
            session = InferenceSession(model_bytes, providers=onnx_providers)
            score = session.run(None, {"X": pooled[None]})[0].squeeze(1)
            predictions[name] = score.tolist()[0]

    predictions['Creak_mean'] = get_creak_label(example)

    scaled = [predictions[key] / 100 for key in pvq_labels + ["Creak_mean"]]
    return torch.tensor(scaled, device=device).float()

def _resample_and_vad(signal, in_rate, out_rate):
    """Resample ``signal`` to ``out_rate``, add a leading axis if 1-D, and run ``EnergyVAD`` on it."""
    resampled = resample_sox(signal, in_rate=in_rate, out_rate=out_rate)
    if resampled.ndim == 1:
        resampled = resampled[None, :]
    vad = EnergyVAD(sample_rate=out_rate)
    return vad({'audio_data': resampled})['audio_data']


def load_audio_files(example):
    """Load the example's audio and attach 16 kHz and 24 kHz versions of it.

    Reads the file at ``example['audio_path']['observation']`` and stores the
    processed signals under ``example['loaded_audio_data']['16_000']`` and
    ``['24_000']``.

    Both versions are derived from the audio as loaded, at its own sample
    rate. The original fed the 16 kHz, VAD-processed signal to the second
    resampling while declaring ``in_rate=sr``, which is wrong whenever the
    file is not already at 16 kHz.

    Args:
        example: dict with ``example['audio_path']['observation']``. It is
            modified in place; callers rely on that.

    Returns:
        The same ``example`` dict.
    """
    observation_loaded, sr = pb.io.load_audio(
        example['audio_path']['observation'],
        return_sample_rate=True,
    )

    example['loaded_audio_data'] = {
        '16_000': _resample_and_vad(observation_loaded, sr, 16_000),
        '24_000': _resample_and_vad(observation_loaded, sr, 24_000),
    }
    return example

# Get example manipulation

In [None]:
# Utterance to manipulate.
speaker_id = 8820
example_id = "8820_294120_000011_000001"

example = {
    'audio_path': {'observation': f"./Saved_models/Dataset/Audio_files/{example_id}.wav"},
    'speaker_id': speaker_id,
    'example_id': example_id,
    'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth",
    'transcription': "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
}

# Load audio, then derive the speaker embedding and the label vector from it.
example = load_audio_files(example)
d_vector = extract_speaker_embedding(example)
labels = load_speaker_labels(example, config_norm_flow)

# Index 7 is the entry appended after the seven pvq_labels ('Creak_mean').
wav_manipulated = get_manipulation(
    example=example,
    labels=labels[None, :],
    flow=normalizing_flow,
    tts_model=tts_model,
    d_vector=d_vector,
    manipulation_idx=7,
    manipulation_fkt=3,
)

# Manipulated synthesis first, then the original recording for comparison.
display(Audio(wav_manipulated, rate=24_000, normalize=True))
display(Audio(example['audio_path']['observation'], rate=24_000, normalize=True))

In [None]:
# Dataset index: maps example IDs to metadata such as 'speaker_id'.
dataset_dict = pb.io.load_yaml('./Saved_models/Dataset/dataset.yaml')

# --- Input widgets -------------------------------------------------------
example_id_widget = widgets.Dropdown(
    description='Example ID: ',
    options=dataset_dict['dataset'].keys(),
    value='1034_121119_000028_000001',
    style={'description_width': 'initial'},
)

# The option values are column indices into the label vector.
manipulation_idx_widget = widgets.Dropdown(
    description='Type:',
    options=[
        ('Weight', 0),
        ('Resonance', 1),
        ('Breathiness', 2),
        ('Roughness', 3),
        ('Creak', 7),
    ],
    value=2,
    style={'description_width': 'initial'},
)

# Amount added to the selected label column.
manipulation_fkt_widget = widgets.FloatSlider(
    description='Strength:',
    value=1.0,
    min=-2.0,
    max=5.0,
    step=0.1,
    style={'description_width': 'initial'},
)

transcription_widget = widgets.Text(
    description='String:',
    value="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    placeholder='Type something',
    disabled=False,
    layout=widgets.Layout(width='900px'),
)

# --- Trigger and output area ----------------------------------------------
run_button = widgets.Button(description="Run Manipulation")

audio_output = widgets.Output()

# --- Cache for the most recently loaded example -----------------------------
# Filled by update_manipulation so repeated runs on one example skip reloading.
cached_example_id = None
cached_loaded_example = None
cached_labels = None
cached_d_vector = None

def update_manipulation(b):
    """Button callback: synthesize the selected example with the chosen manipulation.

    Reads the example ID, manipulation type, strength and text from the
    widgets. Audio, d-vector and labels are recomputed only when the example
    ID differs from the cached one. The text currently in the transcription
    widget is always used, including when the cache is reused.

    Args:
        b: the clicked button (unused).
    """
    global cached_example_id, cached_loaded_example, cached_labels, cached_d_vector

    with audio_output:
        clear_output(wait=True)
        display(widgets.Label("Processing...."))

    example_id = example_id_widget.value.strip()
    speaker_id = dataset_dict['dataset'][example_id]['speaker_id']
    transcription = transcription_widget.value.strip()

    example = {
        'audio_path': {'observation': f"./Saved_models/Dataset/Audio_files/{example_id}.wav"},
        'd_vector_storage_root': f"./Saved_models/Dataset/Embeddings/{speaker_id}/{example_id}.pth",
        'speaker_id': speaker_id,
        'example_id': example_id,
        'transcription': transcription,
    }

    if cached_example_id != example_id:
        with audio_output:
            clear_output(wait=True)
            display(widgets.Label(f"🔄 Loading new example: {example_id}"))

            loaded_example = load_audio_files(example)
            d_vector = extract_speaker_embedding(loaded_example)
            labels = load_speaker_labels(loaded_example, config_norm_flow)

            # Commit the cache only after every step succeeded, so it never
            # mixes data from two different examples.
            cached_loaded_example = loaded_example
            cached_d_vector = d_vector
            cached_labels = labels
            cached_example_id = example_id

        # The Output context manager may swallow an exception raised while
        # loading (it does so under IPython, after showing the traceback).
        # In that case nothing was committed and there is nothing to synthesize.
        if cached_example_id != example_id:
            return

    # Overlay the current text so edits take effect on a cache hit as well.
    wav_manipulated = get_manipulation(
        example={**cached_loaded_example, 'transcription': transcription},
        d_vector=cached_d_vector,
        labels=cached_labels[None, :],
        flow=normalizing_flow,
        tts_model=tts_model,
        manipulation_idx=manipulation_idx_widget.value,
        manipulation_fkt=manipulation_fkt_widget.value,
    )

    with audio_output:
        clear_output(wait=True)
        print('Manipulated Speaker')
        display(Audio(wav_manipulated, rate=24_000, normalize=True))
        print('Original Speaker')
        display(Audio(example['audio_path']['observation'], rate=24_000, normalize=True))

# Hook the callback up to the button, then render the controls and output area.
run_button.on_click(update_manipulation)
display(
    example_id_widget,
    manipulation_fkt_widget,
    transcription_widget,
    manipulation_idx_widget,
    run_button,
    audio_output,
)