In [None]:
import os
import math
import torch
import random
import numpy as np
from pathlib import Path
from IPython.display import Audio
import matplotlib.pyplot as plt
from audiotoken import AudioToken, Tokenizers

from tts.long_infer import AudioSemantic, normalize_text, generate_long
from tts.infer import AudioSemantic as VanillaAudioSemantic
from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, TEXT, ctx

In [None]:
# Two inference front-ends: `AudioSemantic` from `tts.long_infer` and the one
# from `tts.infer` (imported above under the alias `VanillaAudioSemantic`).
ttslib = AudioSemantic()
vanilla_ttslib = VanillaAudioSemantic()

# Acoustic and semantic tokenizers, both created on the same CUDA device.
tokenizer_device = 'cuda:0'
acoustic_tokenizer = AudioToken(Tokenizers.acoustic, device=tokenizer_device)
semantic_tokenizer = AudioToken(Tokenizers.semantic_s, device=tokenizer_device)

In [None]:
def replace_consecutive(arr):
    """Collapse every run of equal adjacent values down to its first element.

    Args:
        arr: 1-D NumPy array (this notebook also passes 1-D CPU torch tensors).

    Returns:
        The elements of ``arr`` that differ from their predecessor, in their
        original order. Empty input is returned unchanged.
    """
    # With no elements the mask below would still contain its leading ``True``,
    # so the boolean index would not match the empty array and indexing would
    # raise IndexError. Nothing to collapse anyway, so return early.
    if len(arr) == 0:
        return arr
    # The first element always survives; every later element survives only
    # when it differs from the element right before it.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]

Prepare the prompts

In [None]:
# Encode the same reference clip with both tokenizers (acoustic first).
prompt_wav = Path('prompts/jenny_prompt.wav')
prompt_aco_toks = acoustic_tokenizer.encode(prompt_wav)
prompt_sem_toks = semantic_tokenizer.encode(prompt_wav)

# Keep entry [0][0] of the semantic output and drop immediate repeats.
prompt_sem_toks = replace_consecutive(prompt_sem_toks[0][0])

# First two acoustic codebooks of item 0. The second one is shifted by 1024
# (presumably the codebook size -- TODO confirm) and the two are interleaved
# position by position: [cb0[0], cb1[0] + 1024, cb0[1], cb1[1] + 1024, ...].
first_codebook, second_codebook = prompt_aco_toks[0, :2, :]
flat_aco_toks = torch.stack([first_codebook, second_codebook + 1024], dim=1).flatten()

# Prompt for the semantic -> acoustic stage.
sa_prompt_toks_dict = {
    'source_tokens': prompt_sem_toks.numpy().astype(np.int64),
    'target_tokens': flat_aco_toks.numpy().astype(np.int64),
}

# Prompt for the text -> semantic stage. An alternative transcript that was
# tried here and left disabled:
# 'many animals of even complex structure which live parasitically within others are wholly devoid of an alimentary cavity <period>'
prompt_text = 'said meg impatiently <period>'
prompt_text_toks = ttslib.text_tokenizer.encode(prompt_text)
ts_prompt_toks_dict = {
    'source_tokens': np.array(prompt_text_toks).astype(np.int64),
    'target_tokens': prompt_sem_toks.numpy().astype(np.int64),
}

In [None]:
# Cell output: shapes of the semantic -> acoustic prompt arrays, as (source, target).
sa_prompt_toks_dict['source_tokens'].shape, sa_prompt_toks_dict['target_tokens'].shape

In [None]:
# Cell output: shapes of the text -> semantic prompt arrays, as (source, target).
ts_prompt_toks_dict['source_tokens'].shape, ts_prompt_toks_dict['target_tokens'].shape

In [None]:
# Round-trip check: decode the prompt's acoustic tokens and listen to the
# first decoded item (played back at 24 kHz).
auds = acoustic_tokenizer.decode(prompt_aco_toks)
prompt_waveform = auds[0]
Audio(prompt_waveform, rate=24000)

Text to semantic

In [None]:
from datasets import load_dataset

# Load the TinyStories corpus and report how many rows its train split has.
ds = load_dataset("roneneldan/TinyStories")
train_split_size = len(ds['train'])
print(train_split_size)

In [None]:
# Pick one story uniformly at random from the train split. No seed is set,
# so every run of this cell selects a different text.
train_split = ds['train']
k = random.sample(range(len(train_split)), 1)[0]
random_txt = train_split[k]['text']

print(random_txt)

In [None]:
# Generation settings for the text -> semantic stage. The text prompt built
# earlier is supplied as `prompt_dict`; what the window/overlap values mean is
# defined by `text_to_semantic_long` in `tts.long_infer`.
text_to_semantic_kwargs = dict(
    max_source_tokens=32,
    source_overlap=16,
    temperature=0.99,
    max_new_tokens=1024,
    prompt_dict=ts_prompt_toks_dict,
)
sem_toks = ttslib.text_to_semantic_long(random_txt, **text_to_semantic_kwargs)

print(sem_toks.shape)

In [None]:
# Synthesise the semantic tokens in fixed-size slices with the vanilla
# pipeline, playing each slice as soon as it is ready. A slice that fails is
# reported and skipped so the remaining slices are still rendered.
auds = []
chunk_len = 150

for start_idx in range(0, sem_toks.shape[-1], chunk_len):
    end_idx = start_idx + chunk_len
    try:
        print(start_idx, end_idx)
        # Slicing past the end is fine: the last slice is simply shorter.
        aud = vanilla_ttslib.semantic_to_audio(sem_toks[start_idx:end_idx])
        print(aud.shape)
        auds.append(aud)
        display(Audio(aud[0], rate=24000))
    except Exception as err:
        print(err)

In [None]:
# Re-synthesise only the tail of the semantic sequence.
# NOTE(review): 1050 is hard-coded while `sem_toks` comes from a randomly
# sampled story, so this offset presumably matched one particular run. If
# `sem_toks` is shorter than that, the slice is empty.
tail_sem_toks = sem_toks[1050:]
aud = vanilla_ttslib.semantic_to_audio(tail_sem_toks)
display(Audio(aud[0], rate=24000))

Semantic to acoustic

In [None]:
# Settings for the semantic -> acoustic stage, driven through `generate_long`
# with the acoustic prompt built earlier.
semantic_to_acoustic_kwargs = dict(
    model=ttslib.semantic_acoustic_model,
    source=SEMANTIC,
    target=ACOUSTIC,
    source_tokens=sem_toks,
    device='cuda:0',
    max_new_tokens=1024,
    temperature=0.9,
    top_k=100,
    max_source_tokens=128,
    source_overlap=64,
    prompt_dict=sa_prompt_toks_dict,
)
# `st` and `gt` are extra return values that later cells inspect.
acoustic_tokens, st, gt = generate_long(**semantic_to_acoustic_kwargs)
print(acoustic_tokens.shape)

In [None]:
# Bind the cache-release helper under a short name (a later cell calls
# `empty_cache()` again) and invoke it once now.
empty_cache = torch.cuda.empty_cache
empty_cache()

In [None]:
# Decode the generated acoustic tokens and play the first decoded item.
generated_aco = torch.tensor(acoustic_tokens)
wav = ttslib.acoustic_tokenizer.decode(generated_aco)
generated_waveform = wav[0].cpu().numpy()
display(Audio(generated_waveform, rate=24000))
empty_cache()

# Play the decoded prompt right after it for a side-by-side comparison.
auds = acoustic_tokenizer.decode(prompt_aco_toks)
display(Audio(auds[0], rate=24000))

In [None]:
# Cell output: the story text that was fed to `text_to_semantic_long` above.
random_txt

In [None]:
# Entry 1 of `st`, the second value returned by `generate_long`.
# NOTE(review): presumably the source (semantic) tokens of the second
# generation window -- confirm against `generate_long`.
x = st[1]

In [None]:
# Remove the semantic offset from the first 256 positions of row 0, then run
# those ids through the vanilla pipeline to hear what this slice contains.
semantic_offset = cfg.OFFSET[SEMANTIC]
temp_sem = x[0, :256] - semantic_offset
temp_sem_np = temp_sem.cpu().numpy()
temp_aco = vanilla_ttslib.semantic_to_audio(temp_sem_np)
Audio(temp_aco[0], rate=24000)

In [None]:
# Entry 1 of `gt` (third value returned by `generate_long`) with the acoustic
# offset removed.
# NOTE(review): presumably the generated acoustic tokens of the second
# window -- confirm against `generate_long`.
z = gt[1] - cfg.OFFSET[ACOUSTIC]

In [None]:
# Cell output: element-wise flag for every second entry of `z` along its
# first axis that exceeds 1024.
# NOTE(review): if `z` uses the interleaved layout built for the prompt
# (even positions = first codebook, odd positions = second codebook + 1024),
# valid even-position ids are presumably 0..1023, so `>= 1024` may be the
# intended bound -- confirm. Also confirm `z` is 1-D; on a 2-D tensor `[::2]`
# would step over rows rather than token positions.
z[::2] > 1024

Creating a prompt

! ffmpeg -y -v 0 -i LJ025-0076.wav -acodec libmp3lame -b:a 64k female_prompt_2.wav

In [None]:
# Encode the new reference clip with both tokenizers (acoustic first).
prompt_wav = Path('prompts/lj_female_long/female_prompt_2.wav')
prompt_aco_toks = acoustic_tokenizer.encode(prompt_wav)
prompt_sem_toks = semantic_tokenizer.encode(prompt_wav)

# Keep entry [0][0] of the semantic output and drop immediate repeats.
prompt_sem_toks = replace_consecutive(prompt_sem_toks[0][0])

# First two acoustic codebooks of item 0. The second one is shifted by 1024
# (presumably the codebook size -- TODO confirm) and the two are interleaved
# position by position.
first_codebook, second_codebook = prompt_aco_toks[0, :2, :]
flat_aco_toks = torch.stack([first_codebook, second_codebook + 1024], dim=1).flatten()

# Text tokens stored alongside the audio tokens of this prompt.
prompt_text = 'many animals of even complex structure which live parasitically within others are wholly devoid of an alimentary cavity <period>'
txt_toks = np.array(ttslib.text_tokenizer.encode(prompt_text)).astype(np.int64)

In [None]:
# Round-trip check for the new prompt: decode its acoustic tokens and listen
# to the first decoded item (played back at 24 kHz).
auds = acoustic_tokenizer.decode(prompt_aco_toks)
prompt_waveform = auds[0]
Audio(prompt_waveform, rate=24000)

In [None]:
# Bundle the three token arrays of the prompt into one .npz archive; each
# dict key becomes the name of an array inside the file.
prompt_arrays = {
    'semantic_tokens': prompt_sem_toks.numpy().astype(np.int64),
    'acoustic_tokens': flat_aco_toks.numpy().astype(np.int64),
    'text_tokens': txt_toks,
}
np.savez('lj_female_long.npz', **prompt_arrays)