In [None]:
import sys
sys.path.append('..')

In [None]:
import os
import math
import torch
import random
import numpy as np
from pathlib import Path
from torch.cuda import empty_cache
from IPython.display import Audio
import matplotlib.pyplot as plt
from audiotoken import AudioToken, Tokenizers

from tts.long_infer import AudioSemantic, normalize_text, generate_long, generate
from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, TEXT, DEVICE, ctx

In [None]:
# Instantiate AudioSemantic (imported from tts.long_infer). Later cells use its
# text_tokenizer, acoustic_tokenizer and semantic->acoustic model attributes.
ttslib = AudioSemantic()

Prepare the prompts

In [None]:
# Load the pre-tokenised prompt archive once, then expose it as the two
# source/target pairings that the generation calls below receive as `prompt_dict`.
prompt_tokens = np.load('../prompts/jenny_short/tokens.npz')

# Pairing for the semantic -> acoustic stage.
sa_prompt_toks_dict = dict(
    source_tokens=prompt_tokens['SEMANTIC'],
    target_tokens=prompt_tokens['ACOUSTIC'],
)

# Pairing for the text -> semantic stage.
ts_prompt_toks_dict = dict(
    source_tokens=prompt_tokens['TEXT'],
    target_tokens=prompt_tokens['SEMANTIC'],
)

In [None]:
# Shapes of the semantic (source) and acoustic (target) prompt arrays.
sa_prompt_toks_dict['source_tokens'].shape, sa_prompt_toks_dict['target_tokens'].shape

In [None]:
# Shapes of the text (source) and semantic (target) prompt arrays.
ts_prompt_toks_dict['source_tokens'].shape, ts_prompt_toks_dict['target_tokens'].shape

In [None]:
# Decode the prompt's TEXT tokens with the text tokenizer to inspect the prompt's wording.
ttslib.text_tokenizer.decode(prompt_tokens['TEXT'])

Text to semantic

In [None]:
# Load the "roneneldan/TinyStories" dataset as a source of sample text and print
# the number of rows in its 'train' split.
# NOTE(review): this import sits mid-notebook; consider moving it to the imports cell.
from datasets import load_dataset
ds = load_dataset("roneneldan/TinyStories")
print(len(ds['train']))

In [None]:
# Fixed story index so the cell yields the same text on every run.
# The previous random draw was overwritten immediately by this assignment, so it
# has been removed. To sample a different story instead, use:
#     k = random.randrange(len(ds['train']))
k = 1120689
random_txt = ds['train'][k]['text']

print(k, random_txt)

In [None]:
# Replace the dataset sample chosen above with a fixed, hand-written story; this is
# the text every later cell synthesises.
random_txt = """
Once upon a time, in a cozy little house at the edge of a forest, lived a curious mouse named Pip. Pip loved to explore, but he had never ventured beyond his garden gate.
One sunny morning, Pip decided it was time for an adventure. He packed a tiny backpack with cheese and crackers, then scurried out the gate.
As Pip wandered through the forest, he met a friendly rabbit named Flopsy. "Where are you going?" Flopsy asked.
"I'm on an adventure!" Pip replied excitedly.
"""

In [None]:
# Text -> semantic: generate semantic tokens for `random_txt`, conditioned on the
# text/semantic prompt pair. Source windowing options are listed first, then the
# sampling options.
sem_toks = ttslib.text_to_semantic_long(
    random_txt,
    prompt_dict=ts_prompt_toks_dict,
    max_source_tokens=32, source_overlap=16,
    temperature=0.99, max_new_tokens=1024,
)

print(sem_toks.shape)

- gen above 8-9 s didn't work well
- max gen as 1200, source toks as 256

- gen above 4-5 s didn't work well
- max gen as 600, source toks as 160

In [None]:
# Token-count table: for each step count, scale 35 source tokens and 150 target
# tokens per step, then add the two to get the combined length.
steps = [1, 5, 10, 15, 20]
source_toks = [35 * n_steps for n_steps in steps]
dest_toks = [150 * n_steps for n_steps in steps]
total_toks = list(map(sum, zip(source_toks, dest_toks)))

steps, source_toks, dest_toks, total_toks

In [None]:
auds = []
gen_kwargs = dict(temperature=0.95, max_source_tokens=768, max_new_tokens=3072)
# Alternative settings kept for reference:
# gen_kwargs = {'temperature': 0.8, 'max_source_tokens': 256, 'max_new_tokens': 1024}

stride = 200

# Synthesise the semantic tokens one fixed-size window at a time and play each
# window's audio as soon as it is ready.
for i in range(0, sem_toks.shape[-1], stride):
    start_idx, end_idx = i, i + stride
    try:
        print(start_idx, end_idx)
        aud = ttslib.semantic_to_audio(
            sem_toks[start_idx:end_idx],
            model=ttslib.semantic_acoustic_model_new,
            **gen_kwargs,
        )
        display(Audio(aud[0], rate=24000))
    except Exception as err:
        # Best-effort: report the failure and continue with the next window.
        print(err)

In [None]:
# Print the input text so it can be compared with the audio played above.
print(random_txt)

In [None]:
# Synthesise progressively longer prefixes of the semantic tokens to hear how the
# output changes with input length.
gen_kwargs = dict(temperature=0.95, max_source_tokens=768, max_new_tokens=2048)

for x in (100, 200, 300, 500, 600, 800):
    aud = ttslib.semantic_to_audio(
        sem_toks[:x],
        model=ttslib.semantic_acoustic_model_new,
        **gen_kwargs,
    )
    display(Audio(aud[0], rate=24000))

In [None]:
# Synthesise the single window sem_toks[450:600] with `semantic_acoustic_model_new`.
gen_kwargs = {'temperature': 0.95, 'max_source_tokens': 768, 'max_new_tokens': 2048}
aud = ttslib.semantic_to_audio(sem_toks[450:600], model=ttslib.semantic_acoustic_model_new, **gen_kwargs)
display(Audio(aud[0], rate=24000))

In [None]:
# Same window (sem_toks[450:600]) synthesised with `semantic_acoustic_model` (no
# `_new` suffix) and smaller generation limits, for side-by-side listening.
gen_kwargs = {'temperature': 0.8, 'max_source_tokens': 256, 'max_new_tokens': 1024}
aud = ttslib.semantic_to_audio(sem_toks[450:600], model=ttslib.semantic_acoustic_model, **gen_kwargs)
display(Audio(aud[0], rate=24000))

Semantic to acoustic

In [None]:
# Size of the semantic token sequence that is fed to the acoustic stage below.
sem_toks.shape

In [None]:
# Semantic -> acoustic generation over the whole semantic token sequence,
# conditioned on the semantic/acoustic prompt pair. The call returns three values;
# `gt` is indexed chunk by chunk in the cells below.
#
# `DEVICE` is the constant imported from `common`. The lowercase `device` used
# previously is not defined anywhere in this notebook, so the call raised
# NameError on a fresh kernel.
acoustic_tokens, st, gt = generate_long(
    model=ttslib.semantic_acoustic_model_new,
    source=SEMANTIC,
    target=ACOUSTIC,
    source_tokens=sem_toks,
    device=DEVICE,
    temperature=0.95,
    max_new_tokens=1500,
    max_source_tokens=100,
    source_overlap=50,
    prompt_dict=sa_prompt_toks_dict
)

In [None]:
# Call torch.cuda.empty_cache() before decoding, then decode the generated acoustic
# tokens to a waveform and play the first item at 24000 Hz.
empty_cache()
wav = ttslib.acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
display(Audio(wav[0].cpu().numpy(), rate=24000))

In [None]:
# Decode and play every entry of `gt` separately so individual chunks can be
# listened to in isolation.
for idx in range(len(gt)):
    # Subtract the acoustic offset from the config before decoding
    # (presumably mapping model token ids back to tokenizer ids — TODO confirm).
    a = gt[idx] - cfg.OFFSET[ACOUSTIC]
    wav = ttslib.acoustic_tokenizer.decode(torch.tensor(a))
    display(Audio(wav[0].cpu().numpy(), rate=24000))

In [None]:
# Display the raw input text (repr form, so newlines are visible).
random_txt

In [None]:
# Layout check on one generated chunk: after removing the acoustic offset, even
# positions are expected to be < 1024 and odd positions >= 1024. The output lists
# the positions that violate each expectation.
# (Presumably this mirrors the interleaved two-row layout built in the prompt
# section, where the second row is shifted by 1024 — TODO confirm.)
idx = 2

t = gt[idx] - cfg.OFFSET[ACOUSTIC]
cb1 = t[::2] < 1024
cb2 = t[1::2] >= 1024

np.where(~cb1), np.where(~cb2)

Creating a prompt

! ffmpeg -y -v 0 -i LJ025-0076.wav -acodec libmp3lame -b:a 64k female_prompt_2.wav

In [None]:
# Tokenizers used to build a new prompt from an audio file.
# `DEVICE` is the constant imported from `common`; the lowercase `device` used
# previously is not defined anywhere in this notebook and raised NameError on a
# fresh kernel.
acoustic_tokenizer = AudioToken(Tokenizers.acoustic, device=DEVICE)
semantic_tokenizer = AudioToken(Tokenizers.semantic_s, device=DEVICE)

def replace_consecutive(arr):
    """Collapse runs of repeated values, keeping the first element of each run.

    Args:
        arr: 1-D sequence supporting ``len``, slicing, elementwise ``!=`` and
            boolean-mask indexing (e.g. a NumPy array).

    Returns:
        ``arr`` with consecutive duplicates removed. An empty input is returned
        unchanged.
    """
    # An empty input would otherwise get a length-1 mask ([True]) that cannot
    # index a length-0 array, raising IndexError.
    if len(arr) == 0:
        return arr
    # The first element always survives; each later element survives only when it
    # differs from its predecessor.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]

In [None]:
def hubert_processor(audio, processor):
    """Run ``processor`` on ``audio`` and return the first ``input_values`` entry.

    Args:
        audio: Passed through to ``processor`` as its first positional argument.
        processor: Callable invoked with ``sampling_rate=16_000`` and
            ``return_tensors='pt'``; its result must expose ``input_values``.

    Returns:
        ``input_values[0]`` of the processor's output.
    """
    call_kwargs = {'sampling_rate': 16_000, 'return_tensors': 'pt'}
    features = processor(audio, **call_kwargs)
    return features.input_values[0]
# Feature extractor loaded from the 'voidful/mhubert-base' checkpoint.
# NOTE(review): `Wav2Vec2FeatureExtractor` is not imported in any visible cell —
# presumably `from transformers import Wav2Vec2FeatureExtractor`; confirm and add
# it to the imports cell, otherwise this line raises NameError on a fresh kernel.
processor = Wav2Vec2FeatureExtractor.from_pretrained('voidful/mhubert-base')

In [None]:
# Audio file the new prompt is built from.
# NOTE(review): this path is relative to the current working directory, whereas the
# prompt loaded at the top of the notebook uses '../prompts/...'. Confirm which
# base directory is intended.
prompt_path = 'prompts/lj_female_long/audio.wav'

In [None]:
# Re-create the acoustic and semantic tokenizers (this repeats the construction at
# the start of the "Creating a prompt" section).
# `DEVICE` is the constant imported from `common`; the lowercase `device` used
# previously is not defined anywhere in this notebook and raised NameError on a
# fresh kernel.
acoustic_tokenizer = AudioToken(Tokenizers.acoustic, device=DEVICE)
semantic_tokenizer = AudioToken(Tokenizers.semantic_s, device=DEVICE)

In [None]:
# Read the prompt audio (second argument 16000, presumably the target sample rate
# in Hz — TODO confirm), run it through the feature extractor, and keep the first
# `input_values` entry.
# NOTE(review): `read_audio_file` is not imported in any visible cell; confirm where
# it comes from and add it to the imports cell.
# NOTE(review): `aud` is rebound three times, so re-running this cell out of order
# operates on the wrong kind of object.
aud = read_audio_file(Path(prompt_path), 16000)
aud = processor(aud, sampling_rate=16000)
aud = aud['input_values'][0]

In [None]:
# Tokenise the prompt: acoustic tokens straight from the file, semantic tokens from
# the preprocessed audio produced in the previous cell.
prompt_aco_toks = acoustic_tokenizer.encode(Path(prompt_path))
prompt_sem_toks = semantic_tokenizer.encode(aud)

# Take element [0][0] of the semantic output and collapse runs of repeated tokens.
prompt_sem_toks = replace_consecutive(prompt_sem_toks[0][0])

# Keep the first two rows along dim 1 of the first item (presumably the first two
# codebooks — TODO confirm), shift the second row by 1024, then interleave the two
# rows element-wise: [r0[0], r1[0], r0[1], r1[1], ...].
flat_aco_toks = prompt_aco_toks[0, :2, :].clone()
flat_aco_toks[1] += 1024
flat_aco_toks = torch.stack([flat_aco_toks[0], flat_aco_toks[1]], dim=1).flatten()

# Transcript tokens for the prompt; the commented line is an alternative transcript.
# txt_toks = np.array(ttslib.text_tokenizer.encode('many animals of even complex structure which live parasitically within others are wholly devoid of an alimentary cavity <period>')).astype(np.int64)
txt_toks = np.array(ttslib.text_tokenizer.encode('said meg impatiently <period>')).astype(np.int64)

In [None]:
# Round-trip check: decode the prompt's acoustic tokens back to audio and play the
# first item at 24000 Hz.
auds = acoustic_tokenizer.decode(prompt_aco_toks)
Audio(auds[0], rate=24000)

In [None]:
# Persist the new prompt under the same three keys (SEMANTIC, ACOUSTIC, TEXT) that
# the prompt-loading cell at the top of the notebook reads; all arrays are int64.
np.savez(
    'prompts/lj_female_long/tokens.npz',
    **{
        'SEMANTIC': prompt_sem_toks.numpy().astype(np.int64),
        'ACOUSTIC': flat_aco_toks.numpy().astype(np.int64),
        'TEXT': txt_toks,
    },
)

In [None]:
# Load the existing jenny_short prompt archive, presumably to compare its layout
# with the newly saved one.
# NOTE(review): the path is relative to the working directory, unlike the
# '../prompts/...' path used at the top of the notebook — confirm.
prev = np.load('prompts/jenny_short/tokens.npz')