In [None]:
import os
import torch
import math
import numpy as np
from pathlib import Path
from IPython.display import Audio
import matplotlib.pyplot as plt

from audiotoken import AudioToken, Tokenizers

from tts.long_infer import AudioSemantic, normalize_text, generate_long
from tts.infer import AudioSemantic as VanillaAudioSemantic
from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, TEXT, ctx

In [None]:
# TTS pipelines: `AudioSemantic` from tts.long_infer (used with generate_long below)
# and the vanilla one from tts.infer (used for single-shot baselines and decoding).
ttslib = AudioSemantic()
vanilla_ttslib = VanillaAudioSemantic()

# Standalone tokenizers used to encode/decode the voice prompt.
tokenizer_device = 'cuda:0'
acoustic_tokenizer = AudioToken(Tokenizers.acoustic, device=tokenizer_device)
semantic_tokenizer = AudioToken(Tokenizers.semantic_s, device=tokenizer_device)

In [None]:
def replace_consecutive(arr):
    """Collapse each run of equal adjacent elements into a single element.

    Example: [1, 1, 2, 2, 2, 3, 1, 1] -> [1, 2, 3, 1].

    Args:
        arr: 1-D array-like supporting slicing, ``!=`` and boolean-mask indexing
            (e.g. a NumPy array or a CPU torch tensor).

    Returns:
        Same type as ``arr``: the first element of every run, in order.
        An empty input returns an empty slice instead of raising.
    """
    if len(arr) == 0:
        # The mask below always has at least one element; it would not match an
        # empty input and indexing would raise.
        return arr[:0]
    # Keep element 0, plus every element that differs from its predecessor.
    keep = np.ones(len(arr), dtype=bool)
    keep[1:] = np.asarray(arr[1:] != arr[:-1])
    return arr[keep]

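## Prepare the prompts

Encode the reference clip `prompts/jenny_prompt.wav` into acoustic and semantic tokens, and build the prompt dictionaries for both generation stages.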

In [None]:
# Voice prompt shared by both generation stages.
PROMPT_WAV = Path('prompts/jenny_prompt.wav')
# Shift added to codebook-1 ids when interleaving so the two codebooks use
# separate id ranges. NOTE(review): presumably the acoustic codebook size — confirm.
SECOND_CODEBOOK_OFFSET = 1024

prompt_aco_toks = acoustic_tokenizer.encode(PROMPT_WAV)
# Take row [0][0] of the semantic encoding and collapse runs of repeated ids.
prompt_sem_toks = replace_consecutive(semantic_tokenizer.encode(PROMPT_WAV)[0][0])

# Interleave the first two acoustic codebooks into one 1-D stream:
#   c0[0], c1[0], c0[1], c1[1], ...
# (indexing assumes a [batch, codebook, time] layout — TODO confirm)
first_two_codebooks = prompt_aco_toks[0, :2, :].clone()
first_two_codebooks[1] += SECOND_CODEBOOK_OFFSET
flat_aco_toks = first_two_codebooks.T.reshape(-1)

# Prompt for the semantic -> acoustic model.
sa_prompt_toks_dict = {
    'source_tokens': prompt_sem_toks.numpy(),
    'target_tokens': flat_aco_toks.numpy()
}

In [None]:
# Sanity check: acoustic prompt tokens vs. run-collapsed semantic prompt tokens.
(prompt_aco_toks.shape, prompt_sem_toks.shape)

In [None]:
# Round-trip the prompt through the acoustic tokenizer and listen to it.
prompt_auds = acoustic_tokenizer.decode(prompt_aco_toks)
Audio(prompt_auds[0], rate=24000)

In [None]:
# Prompt for the text -> semantic model: a fixed text paired with the prompt's
# semantic tokens.
prompt_text = 'said meg impatiently <period>'
prompt_text_ids = np.array(ttslib.text_tokenizer.encode(prompt_text))
ts_prompt_toks_dict = dict(
    source_tokens=prompt_text_ids,
    target_tokens=prompt_sem_toks.numpy(),
)

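## Text to semantic

Generate semantic tokens from text, first one sentence at a time and then for the sentences joined together, and compare the two.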

In [None]:
# Alternative example sentences. They are kept under their own name because the
# next cell assigns txt1..txt6, which would otherwise silently overwrite them.
# To synthesize these instead, run `txt1, txt2, txt3, txt4, txt5 = paris_trip_texts`
# after that cell.
paris_trip_texts = [
    "our adventure began in paris <period>",
    "the eiffel tower amazed us <period>",
    "we enjoyed cafes and croissants <period>",
    "the louvres art was stunning <period>",
    "we finally dived in a swimming pool <period>",
]

In [None]:
# Example sentences synthesized below (punctuation spelled out as tokens).
(txt1, txt2, txt3, txt4, txt5, txt6) = (
    "the breeze was gentle <comma> rustling the leaves of the trees <period>",
    "it was a perfect evening to take a leisurely stroll <period>",
    "every step on the gravel path felt like a soothing rhythm <period>",
    "matching the tranquility of the surroundings <period>",
    "as the sky shifted from orange to deep purple <period>",
    "the first stars began to appear twinkling like tiny diamonds <period>",
)

In [None]:
# Baseline: the vanilla tts.infer pipeline end to end on a single sentence.
baseline_sem = vanilla_ttslib.text_to_semantic(txt1)
baseline_wav = vanilla_ttslib.semantic_to_audio(baseline_sem)
Audio(baseline_wav[0], rate=24000)

In [None]:
# Generate semantic tokens one sentence at a time (prompted with ts_prompt_toks_dict),
# play each sentence, and collect the results into a single stream `sem_toks_diff`.
# NOTE(review): txt6 is defined above but not included here — confirm intent.
per_sentence_sem = []
for sentence in [txt1, txt2, txt3, txt4, txt5]:
    sentence_sem, _, _ = generate_long(
        model=ttslib.text_semantic_model,
        source=TEXT,
        target=SEMANTIC,
        source_tokens=np.asarray(ttslib.text_tokenizer.encode(sentence)),
        max_source_tokens=16,
        source_overlap=8,
        device='cuda:0',
        temperature=0.8,
        prompt_dict=ts_prompt_toks_dict
    )
    per_sentence_sem.append(sentence_sem)
    # Raw length vs. length after collapsing runs of repeated ids.
    print(sentence_sem.shape, replace_consecutive(sentence_sem).shape)
    aud = vanilla_ttslib.semantic_to_audio(sentence_sem)
    display(Audio(aud[0], rate=24000))

# One concatenate instead of extending a Python list element by element.
sem_toks_diff = np.concatenate(per_sentence_sem)
sem_toks_diff.shape

In [None]:
# Same text -> semantic settings as the per-sentence loop, but applied to the
# first three sentences joined into one token sequence.
sentences = ' '.join((txt1, txt2, txt3))
sentence_tokens = np.asarray(ttslib.text_tokenizer.encode(sentences))

text_to_semantic_kwargs = dict(
    model=ttslib.text_semantic_model,
    source=TEXT,
    target=SEMANTIC,
    device='cuda:0',
    max_source_tokens=16,
    source_overlap=8,
    temperature=0.8,
)
sem_toks, st, gt = generate_long(
    source_tokens=sentence_tokens,
    prompt_dict=ts_prompt_toks_dict,
    **text_to_semantic_kwargs,
)

In [None]:
# Compare the semantic-token distributions of the joined-text run (txt1-txt3)
# and the per-sentence run (txt1-txt5).
# - Shared bin edges: two independent plt.hist calls each pick their own bins
#   over their own min/max, so an overlay would compare different bins.
# - density=True: the runs cover different amounts of text, so raw counts
#   are not comparable.
joined_toks = np.ravel(sem_toks)
per_sentence_toks = np.ravel(sem_toks_diff)
shared_bins = np.histogram_bin_edges(
    np.concatenate([joined_toks, per_sentence_toks]), bins=50
)

fig, ax = plt.subplots()
ax.hist(joined_toks, bins=shared_bins, density=True, label='txt1-txt3 joined')
ax.hist(per_sentence_toks, bins=shared_bins, density=True, alpha=0.5,
        label='txt1-txt5 per sentence')
ax.set(title='Semantic token distribution', xlabel='semantic token id', ylabel='density')
ax.legend()
plt.show()

In [None]:
# Decode the per-sentence semantic stream in fixed-size windows to hear where
# quality changes. Best-effort: a window that fails to decode is reported and skipped.
DECODE_CHUNK = 150
auds = []
n_toks = sem_toks_diff.shape[-1]

for start_idx in range(0, n_toks, DECODE_CHUNK):
    # Clamp so the printed range is accurate for the final, shorter window.
    end_idx = min(start_idx + DECODE_CHUNK, n_toks)
    print(start_idx, end_idx)
    try:
        aud = vanilla_ttslib.semantic_to_audio(sem_toks_diff[start_idx:end_idx])
        print(aud.shape)
        auds.append(aud)
        display(Audio(aud[0], rate=24000))
    except Exception as err:
        # Keep going, but say which window failed and why.
        print(f'window {start_idx}:{end_idx} failed: {type(err).__name__}: {err}')

In [None]:
# Decode everything from token 300 on in one go (300 = 2 * 150, i.e. past the
# first two decode windows).
tail_start = 300
tail_sem = sem_toks_diff[tail_start:]
aud = vanilla_ttslib.semantic_to_audio(tail_sem)
display(Audio(aud[0], rate=24000))

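## Semantic to acoustic

Turn the per-sentence semantic stream into acoustic tokens, decode them to audio, and inspect individual generation windows.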

In [None]:
# Windowed semantic -> acoustic generation over the per-sentence semantic stream
# (source windows of up to 128 tokens, overlap 64), prompted with the flattened
# voice-prompt tokens.
semantic_to_acoustic_kwargs = dict(
    model=ttslib.semantic_acoustic_model,
    source=SEMANTIC,
    target=ACOUSTIC,
    device='cuda:0',
    temperature=0.9,
    top_k=50,
)
acoustic_tokens, st, gt = generate_long(
    source_tokens=sem_toks_diff,
    max_source_tokens=128,
    source_overlap=64,
    prompt_dict=sa_prompt_toks_dict,
    **semantic_to_acoustic_kwargs,
)
print(acoustic_tokens.shape)

In [None]:
# Generated speech: acoustic tokens -> waveform.
wav = ttslib.acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
display(Audio(wav[0].cpu().numpy(), rate=24000))

# Reference: the voice prompt decoded for side-by-side comparison. It is stored
# under its own name so the list of decoded windows in `auds` is not overwritten.
prompt_auds = acoustic_tokenizer.decode(prompt_aco_toks)
display(Audio(prompt_auds[0], rate=24000))

In [None]:
# NOTE(review): presumably `st` holds per-window source tokens from the last
# generate_long call; index 1 selects the second entry — confirm.
window_index = 1
x = st[window_index]

In [None]:
# Remove the semantic vocabulary offset from the first 128 positions of row 0
# (128 matches max_source_tokens of the semantic -> acoustic run) and listen.
semantic_offset = cfg.OFFSET[SEMANTIC]
temp_sem = x[0, :128] - semantic_offset
temp_aco = vanilla_ttslib.semantic_to_audio(temp_sem.cpu().numpy())
Audio(temp_aco[0], rate=24000)

In [None]:
# Entry 1 of the generated targets with the acoustic vocabulary offset removed.
acoustic_offset = cfg.OFFSET[ACOUSTIC]
z = gt[1] - acoustic_offset

In [None]:
# Even positions of the interleaved acoustic stream should hold codebook-0 ids.
# Codebook-1 ids were shifted by +1024 when the prompt was flattened, so the
# codebook-1 range starts AT 1024: use >= (a strict > would miss id 1024).
# Show a count instead of dumping the whole boolean mask.
# NOTE(review): assumes z is 1-D and starts on a codebook-0 position; if gt[1]
# is 2-D (as st[1] is), z[::2] strides rows, not time steps — confirm.
codebook1_at_even = z[::2] >= 1024
int(codebook1_at_even.sum()), tuple(codebook1_at_even.shape)