In [None]:
import os
import torch
import numpy as np
import math
from pathlib import Path

from IPython.display import Audio
from audiotoken import AudioToken, Tokenizers

from tts.long_infer import AudioSemantic
from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, ctx
from tts.long_infer import generate as aco_generate

In [None]:
# Build the TTS pipeline object and the two audio tokenizers used by the cells below.
ttslib = AudioSemantic()
# Both tokenizers are pinned to the first CUDA device.
acoustic_tokenizer = AudioToken(Tokenizers.acoustic, device='cuda:0')
semantic_tokenizer = AudioToken(Tokenizers.semantic_s, device='cuda:0')

In [None]:
def replace_consecutive(arr):
    """Collapse every run of equal adjacent elements to a single element.

    Args:
        arr: 1-D sequence supporting slicing, element-wise ``!=`` and boolean
            mask indexing (e.g. a NumPy array).

    Returns:
        The elements of ``arr`` with consecutive duplicates removed; the
        first element of each run is kept. An empty input is returned as is.
    """
    # An empty input has no neighbours to compare. Without this guard the
    # one-element mask built below would not match the zero-length array
    # and the indexing would raise.
    if len(arr) == 0:
        return arr
    # Keep position 0 unconditionally, then every position whose value
    # differs from its predecessor.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]

- Generate the semantic tokens of the prompt audio.
- Use `ttslib` to generate the corresponding acoustic tokens (2 codebooks).

In [None]:
# Encode the same prompt recording with both tokenizers.
aco_toks = acoustic_tokenizer.encode(Path('female_prompt_short.wav'))
sem_toks = semantic_tokenizer.encode(Path('female_prompt_short.wav'))

# Take the [0][0] entry of the semantic encoding and collapse runs of repeated tokens.
# NOTE(review): presumably [0][0] selects the first batch item / first stream — confirm.
sem_toks = replace_consecutive(sem_toks[0][0])
aco_toks.shape, sem_toks.shape

In [None]:
# Sanity check: decode only the first 150 positions along the last axis and listen to the result.
auds = acoustic_tokenizer.decode(aco_toks[:, :, :150])
Audio(auds[0], rate=24000)

In [None]:
# Semantic -> acoustic generation on the first 64 de-duplicated semantic tokens of the prompt.
aco_gen_toks = aco_generate(
    model=ttslib.semantic_acoustic_model, 
    source_tokens=sem_toks[:64].numpy(),
    source=SEMANTIC,
    target=ACOUSTIC
)
aco_gen_toks.shape

In [None]:
# Generation from the original semantic tokens: the full de-duplicated
# prompt sequence goes straight through ttslib.semantic_to_audio.
auds = ttslib.semantic_to_audio(sem_toks.numpy())
Audio(auds[0], rate=24000)

In [None]:
# Generation from the intermediate 2-codebook acoustic tokens produced by aco_generate above.
auds = ttslib.acoustic_tokenizer.decode(torch.tensor(aco_gen_toks))
Audio(auds[0], rate=24000)

---

In [None]:
# Sample sentences for synthesis. Punctuation is spelled out in the text as
# <comma> / <period> markers rather than literal characters.
txt1 = "my name is romit <comma> and i am trying to build indri <period>"
txt2 = "the breeze was gentle <comma> rustling the leaves of the trees as birds chirped softly in the distance <period>"
txt3 = "it was a perfect evening to take a leisurely stroll <comma> letting the calmness of nature wash over you <period>"
txt4 = "every step on the gravel path felt like a soothing rhythm <comma> matching the tranquility of the surroundings <period>"
txt5 = "as the sky shifted from orange to deep purple <comma> the first stars began to appear <comma> twinkling like tiny diamonds in the vastness above <period>"

In [None]:
# Text -> semantic tokens for the first three sentences joined into one string.
# NOTE: this rebinds `sem_toks`; from here on it no longer holds the prompt's
# semantic tokens computed from the audio file.
sem_toks = ttslib.text_to_semantic(' '.join([txt1, txt2, txt3]))
# Sequence shape and number of distinct token ids.
print(sem_toks.shape, np.unique(sem_toks).shape[0])

In [None]:
# Encode each sentence separately with the text tokenizer.
# NOTE(review): despite the name, these are text-tokenizer outputs, not semantic tokens.
sem_toks_diff = [ttslib.text_tokenizer.encode(t) for t in [txt1, txt2, txt3, txt4, txt5]]

In [None]:
# Number of text tokens per sentence.
[len(s) for s in sem_toks_diff]

In [None]:
# Synthesize each sentence on its own: text -> semantic tokens -> audio,
# printing the intermediate shapes and rendering a player per sentence.
for idx, t in enumerate([txt1, txt2, txt3, txt4, txt5]):
    try:
        s = ttslib.text_to_semantic(t)
        print(s.shape)
        aud = ttslib.semantic_to_audio(s)
        print(aud.shape)
        display(Audio(aud[0], rate=24000))
    except Exception as err:
        # Best effort: report which sentence failed and continue with the rest.
        print(f'err at {idx}, {err}')

In [None]:
# Run semantic_to_audio five times on the same slice sem_toks[100:300],
# keeping every successful output so the samples can be played back below.
auds = []

for i in range(5):
    try:
        aud = ttslib.semantic_to_audio(sem_toks[100:300])
        print(aud.shape)
        auds.append(aud)
    except Exception as err:
        # Report the failure instead of silently skipping it: a bare `except:`
        # also swallows KeyboardInterrupt and hides why an attempt was lost.
        print(f'err at {i}, {err}')


for aud in auds:
    display(Audio(aud[0], rate=24000))

## Legacy generation (`ttslib.semantic_to_audio_long`)

In [None]:
# Run the legacy long-form path ten times on the full semantic sequence,
# keeping every successful output for playback in the next cell.
auds = []

for i in range(10):
    try:
        aud = ttslib.semantic_to_audio_long(sem_toks)
        print(aud.shape)
        auds.append(aud)
    except Exception as err:
        # Report the failure instead of silently skipping it: a bare `except:`
        # also swallows KeyboardInterrupt and hides why an attempt was lost.
        print(f'err at {i}, {err}')

In [None]:
# Render an audio player for every sample collected in `auds`.
for aud in auds:
    display(Audio(aud[0], rate=24000))

## New generation (chunked, with overlapping source windows)

In [None]:
def gen_new(model, source, target, source_tokens, device,
            source_overlap=64, max_new_tokens=1024, temperature=0.8, top_k=100):
    """Generate target tokens for a long source sequence in overlapping chunks.

    The source is walked in windows of ``cfg.max_source_tokens // 2`` tokens
    that overlap by ``source_overlap`` tokens. Each window is sent to
    ``model.generate`` as ``[source window, cfg.INFER_TOKEN[target], tail of
    the targets generated so far]`` and the newly generated tokens are
    appended to the running target sequence.

    Args:
        model: Object with ``generate(tokens, max_new_tokens, temperature=,
            top_k=, stop_token=)``. The code assumes the returned row begins
            with the input tokens and strips that prefix.
        source: Key into ``cfg.OFFSET`` for the source modality.
        target: Key into ``cfg.OFFSET`` / ``cfg.INFER_TOKEN`` /
            ``cfg.STOP_TOKEN`` for the target modality.
        source_tokens: 1-D array of source token ids, without the offset.
        device: Torch device the model input is created on.
        source_overlap: Number of source tokens shared by adjacent windows.
        max_new_tokens: Generation budget passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k cutoff passed to ``model.generate``.

    Returns:
        1-D integer NumPy array of generated target tokens with
        ``cfg.OFFSET[target]`` subtracted.

    Raises:
        ValueError: If the window size does not exceed ``source_overlap``,
            in which case the windows could never advance.
    """
    source_tokens = np.asarray(source_tokens) + cfg.OFFSET[source]
    max_source_tokens = cfg.max_source_tokens//2

    source_stride = max_source_tokens - source_overlap
    if source_stride <= 0:
        raise ValueError(
            f'source_overlap ({source_overlap}) must be smaller than the '
            f'window size ({max_source_tokens})'
        )
    target_overlap = 0

    # Integer accumulator: an untyped empty array is float64 and would turn
    # every token id appended to it (and the returned ids) into floats.
    target_tokens = np.asarray([], dtype=np.int64)

    print(
        f'Source, tokens shape: {source_tokens.shape}, overlap: {source_overlap}, stride: {source_stride}, max tokens: {max_source_tokens}'
    )

    num_source_tokens = len(source_tokens)
    for idx in range(0, num_source_tokens, source_stride):
        end_idx = idx + max_source_tokens
        source_cut = source_tokens[idx: end_idx]
        # `[-0:]` would select the whole history, so handle a zero overlap
        # explicitly and pass no target context in that case.
        if target_overlap > 0:
            target_cut = target_tokens[-target_overlap:]
        else:
            target_cut = target_tokens[:0]

        input_tokens = np.hstack([
            source_cut,
            cfg.INFER_TOKEN[target],
            target_cut
        ])

        input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=device)[None, ...]

        print(f'Source tokens shape: {input_tokens.shape}, start idx: {idx}, end idx: {end_idx}')
        print(f'Target cut shape: {target_cut.shape}, overlap: {target_overlap}')

        with torch.no_grad(), ctx:
            new_target_tokens = model.generate(
                input_tokens,
                max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                stop_token=cfg.STOP_TOKEN[target]
            ).detach().cpu().numpy()[0]
            print(f'Gen shape: {new_target_tokens.shape}')

        # Drop the echoed input so only freshly generated tokens remain.
        new_target_tokens = new_target_tokens[input_tokens.shape[-1]:]

        # Update the target overlap ratio: for x new source tokens we generated
        # y target tokens, so the overlap region maps to overlap * y / x targets.
        # Every window after the first starts with `source_overlap` tokens that
        # the previous window already covered.
        num_source_new_toks = len(source_cut)
        if idx:
            num_source_new_toks -= source_overlap
        target_overlap = source_overlap * new_target_tokens.shape[-1]/num_source_new_toks
        target_overlap = math.ceil(target_overlap)
        print(f'Source toks: {num_source_new_toks}, New target shape: {new_target_tokens.shape}, overlap: {target_overlap}')
        # Merge into existing target tokens
        target_tokens = np.hstack([target_tokens, new_target_tokens])
        print(f'Overall target shape: {target_tokens.shape}')

        print('\n')

        # This window already reached the end of the source. Any later
        # window would contain only tokens that were covered here and would
        # generate the tail a second time.
        if end_idx >= num_source_tokens:
            break

    target_tokens = target_tokens - cfg.OFFSET[target]
    return target_tokens

In [None]:
# Build the flattened acoustic prompt from the first two codebook rows of the prompt encoding.
# .clone() is required: the slice is a view of aco_toks, so the in-place
# offset below would otherwise write into aco_toks itself and add another
# 1024 every time this cell is re-run.
prompt_aco_toks = aco_toks[0, :2, :].clone()
# Shift the second codebook row by 1024.
# NOTE(review): presumably the per-codebook vocabulary size, so the two codebooks do not share ids — confirm.
prompt_aco_toks[1] += 1024
# Interleave the two rows position by position: [row0[0], row1[0], row0[1], row1[1], ...].
prompt_aco_toks = torch.stack([prompt_aco_toks[0], prompt_aco_toks[1]], dim=1).flatten()

In [None]:
# Prompt pair handed to gen_new_prompt: the semantic (source) and acoustic
# (target) tokens of the prompt recording.
# `sem_toks` is rebound to text-derived tokens by an earlier cell, so on a
# top-to-bottom run it no longer holds the prompt's tokens. Derive them here
# from the prompt audio, using the same de-duplication as the encoding cell,
# instead of relying on whatever `sem_toks` currently is.
prompt_sem_toks = replace_consecutive(
    semantic_tokenizer.encode(Path('female_prompt_short.wav'))[0][0]
)
prompt_toks_dict = {
    'source_tokens': prompt_sem_toks.numpy(),
    'target_tokens': prompt_aco_toks.numpy()
}

In [None]:
def gen_new_prompt(model, source, target, source_tokens, prompt_dict, device,
                   source_overlap=64, max_new_tokens=1024, temperature=0.8, top_k=100):
    """Chunked source -> target generation, conditioned on a prompt in the first chunk.

    The source is walked in windows of ``cfg.max_source_tokens // 2`` tokens
    that overlap by ``source_overlap`` tokens. The first window is sent to
    ``model.generate`` as ``[prompt source, source window,
    cfg.INFER_TOKEN[target], prompt target]``; later windows as ``[source
    window, cfg.INFER_TOKEN[target], tail of the targets generated so far]``.

    Args:
        model: Object with ``generate(tokens, max_new_tokens, temperature=,
            top_k=, stop_token=)``. The code assumes the returned row begins
            with the input tokens and strips that prefix.
        source: Key into ``cfg.OFFSET`` for the source modality.
        target: Key into ``cfg.OFFSET`` / ``cfg.INFER_TOKEN`` /
            ``cfg.STOP_TOKEN`` for the target modality.
        source_tokens: 1-D array of source token ids, without the offset.
        prompt_dict: Mapping with optional ``'source_tokens'`` and
            ``'target_tokens'`` arrays. A missing or ``None`` entry (or a
            ``None`` mapping) is treated as an empty prompt.
        device: Torch device the model input is created on.
        source_overlap: Number of source tokens shared by adjacent windows.
        max_new_tokens: Generation budget passed to ``model.generate``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k cutoff passed to ``model.generate``.

    Returns:
        1-D integer NumPy array of generated target tokens with
        ``cfg.OFFSET[target]`` subtracted. The prompt target tokens are not
        part of the result.

    Raises:
        ValueError: If the window size does not exceed ``source_overlap``,
            in which case the windows could never advance.
    """
    source_tokens = np.asarray(source_tokens) + cfg.OFFSET[source]
    max_source_tokens = cfg.max_source_tokens//2

    # Read the prompt from the argument. The previous code ignored the
    # parameter and silently depended on a notebook global.
    prompt_dict = prompt_dict or {}
    prompt_source_tokens = prompt_dict.get('source_tokens')
    prompt_target_tokens = prompt_dict.get('target_tokens')
    if prompt_source_tokens is None:
        prompt_source_tokens = []
    if prompt_target_tokens is None:
        prompt_target_tokens = []
    # NOTE(review): prompt tokens are used exactly as given. `source_tokens`
    # is shifted by cfg.OFFSET[source] above but the prompt tokens are not —
    # confirm whether callers are expected to pre-offset them.
    prompt_source_tokens = np.asarray(prompt_source_tokens, dtype=np.int64)
    prompt_target_tokens = np.asarray(prompt_target_tokens, dtype=np.int64)
    print(f'Prompt source tokens: {prompt_source_tokens.shape}, prompt target tokens: {prompt_target_tokens.shape}')

    source_stride = max_source_tokens - source_overlap
    if source_stride <= 0:
        raise ValueError(
            f'source_overlap ({source_overlap}) must be smaller than the '
            f'window size ({max_source_tokens})'
        )
    target_overlap = 0

    # Integer accumulator: an untyped empty array is float64 and would turn
    # every token id appended to it (and the returned ids) into floats.
    target_tokens = np.asarray([], dtype=np.int64)

    print(
        f'Source, tokens shape: {source_tokens.shape}, overlap: {source_overlap}, stride: {source_stride}, max tokens: {max_source_tokens}'
    )

    num_source_tokens = len(source_tokens)
    for idx in range(0, num_source_tokens, source_stride):
        end_idx = idx + max_source_tokens
        source_cut = source_tokens[idx: end_idx]
        # `[-0:]` would select the whole history, so handle a zero overlap
        # explicitly and pass no target context in that case.
        if target_overlap > 0:
            target_cut = target_tokens[-target_overlap:]
        else:
            target_cut = target_tokens[:0]

        if idx == 0:
            # Only the first window carries the prompt pair.
            input_tokens = np.hstack([
                prompt_source_tokens,
                source_cut,
                cfg.INFER_TOKEN[target],
                prompt_target_tokens,
                target_cut
            ])
        else:
            input_tokens = np.hstack([
                source_cut,
                cfg.INFER_TOKEN[target],
                target_cut
            ])

        input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=device)[None, ...]

        print(f'Source tokens shape: {input_tokens.shape}, start idx: {idx}, end idx: {end_idx}')
        print(f'Target cut shape: {target_cut.shape}, overlap: {target_overlap}')

        with torch.no_grad(), ctx:
            new_target_tokens = model.generate(
                input_tokens,
                max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                stop_token=cfg.STOP_TOKEN[target]
            ).detach().cpu().numpy()[0]
            print(f'Gen shape: {new_target_tokens.shape}')

        # Drop the echoed input (including any prompt) so only freshly
        # generated tokens remain.
        new_target_tokens = new_target_tokens[input_tokens.shape[-1]:]

        # Update the target overlap ratio: for x new source tokens we generated
        # y target tokens, so the overlap region maps to overlap * y / x targets.
        # Every window after the first starts with `source_overlap` tokens that
        # the previous window already covered.
        num_source_new_toks = len(source_cut)
        if idx:
            num_source_new_toks -= source_overlap
        target_overlap = source_overlap * new_target_tokens.shape[-1]/num_source_new_toks
        target_overlap = math.ceil(target_overlap)
        print(f'Source toks: {num_source_new_toks}, New target shape: {new_target_tokens.shape}, overlap: {target_overlap}')
        # Merge into existing target tokens
        target_tokens = np.hstack([target_tokens, new_target_tokens])
        print(f'Overall target shape: {target_tokens.shape}')

        print('\n')

        # This window already reached the end of the source. Any later
        # window would contain only tokens that were covered here and would
        # generate the tail a second time.
        if end_idx >= num_source_tokens:
            break

    target_tokens = target_tokens - cfg.OFFSET[target]
    return target_tokens

In [None]:
# Prompt-conditioned chunked generation: semantic tokens in `sem_toks` -> acoustic tokens.
acoustic_tokens = gen_new_prompt(
    model=ttslib.semantic_acoustic_model, 
    source_tokens=sem_toks,
    source=SEMANTIC,
    target=ACOUSTIC,
    prompt_dict=prompt_toks_dict,
    device='cuda:0'
)

print(acoustic_tokens.shape)

In [None]:
# Decode the prompt-conditioned acoustic tokens to a waveform and play it.
# NOTE(review): an earlier variant decoded acoustic_tokens[:-1], i.e. without
# the final token — presumably to drop a trailing stop token; confirm whether
# that is still needed.
wav = ttslib.acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
Audio(wav[0], rate=24000)

In [None]:
# Show the shape of `sem_toks`, the source sequence passed to gen_new in the next cell.
sem_toks.shape

In [None]:
# Same chunked generation without a prompt; overwrites `acoustic_tokens` from the prompted run.
acoustic_tokens = gen_new(
    model=ttslib.semantic_acoustic_model, 
    source_tokens=sem_toks,
    source=SEMANTIC,
    target=ACOUSTIC,
    device='cuda:0'
)

print(acoustic_tokens.shape)

In [None]:
# Decode the unprompted acoustic tokens to a waveform and play it.
wav = ttslib.acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
Audio(wav[0], rate=24000)