In [None]:
import os
import torch
import math
import numpy as np
from pathlib import Path
from IPython.display import Audio
import matplotlib.pyplot as plt

from audiotoken import AudioToken, Tokenizers

from tts.long_infer import AudioSemantic
from tts.long_infer import generate as aco_generate
from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, TEXT, ctx

In [None]:
# Instantiate the TTS pipeline object and the two audio tokenizers that the
# rest of the notebook uses (acoustic for waveform codes, semantic for content codes).
# NOTE(review): both tokenizers are pinned to 'cuda:0', so a CUDA device is
# presumably required to run this notebook — confirm.
ttslib = AudioSemantic()
acoustic_tokenizer = AudioToken(Tokenizers.acoustic, device='cuda:0')
semantic_tokenizer = AudioToken(Tokenizers.semantic_s, device='cuda:0')

In [None]:
def replace_consecutive(arr):
    """Collapse runs of identical neighbouring values into a single value.

    Only *adjacent* duplicates are removed; a value may still appear again
    later in the result (``[1, 1, 2, 1] -> [1, 2, 1]``).

    Args:
        arr: 1-D sequence supporting slicing, element-wise ``!=`` and boolean
            mask indexing (the notebook passes both numpy arrays and the
            tokenizer output).

    Returns:
        A sequence of the same type as ``arr`` with consecutive repeats removed.
        An empty input is returned unchanged.
    """
    # An empty input would otherwise build a length-1 mask for a length-0
    # array, and boolean indexing would raise IndexError.
    if len(arr) == 0:
        return arr

    # The first element is always kept; every later element is kept only if it
    # differs from its predecessor.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]

Ready the prompts

In [None]:
# Tokenize one prompt recording with both tokenizers.
prompt_wav = Path('prompts/female_prompt_short.wav')
prompt_aco_toks = acoustic_tokenizer.encode(prompt_wav)
prompt_sem_toks = semantic_tokenizer.encode(prompt_wav)

# Take the [0][0] entry of the semantic output and collapse repeated neighbours.
prompt_sem_toks = replace_consecutive(prompt_sem_toks[0][0])
prompt_aco_toks.shape, prompt_sem_toks.shape

In [None]:
# Take the first two rows along dim 1 of the first batch item and shift the
# second row by 1024 so the two rows use disjoint id ranges.
flat_aco_toks = prompt_aco_toks[0, :2, :].clone()
flat_aco_toks[1] += 1024
# Interleave the two rows: [row0[0], row1[0], row0[1], row1[1], ...].
flat_aco_toks = flat_aco_toks.t().flatten()

# Prompt pair handed to the prompted generation helper.
prompt_toks_dict = dict(
    source_tokens=prompt_sem_toks.numpy(),
    target_tokens=flat_aco_toks.numpy(),
)

In [None]:
# Round-trip check: decode the prompt's acoustic tokens back to audio and play it.
# NOTE(review): the 24000 playback rate is presumably the tokenizer's sample rate — confirm.
auds = acoustic_tokenizer.decode(prompt_aco_toks)
Audio(auds[0], rate=24000)

In [None]:
# Generate acoustic tokens from the prompt's semantic tokens with the
# semantic->acoustic model, then show the shape of the result.
aco_gen_toks = aco_generate(
    model=ttslib.semantic_acoustic_model, 
    source_tokens=prompt_sem_toks.numpy(),
    source=SEMANTIC,
    target=ACOUSTIC
)
aco_gen_toks.shape

In [None]:
# Generate audio directly from the prompt's (de-duplicated) semantic tokens
# via the library helper, and play the first item.
auds = ttslib.semantic_to_audio(prompt_sem_toks.numpy())
Audio(auds[0], rate=24000)

In [None]:
# Decode the intermediate acoustic tokens produced by `aco_generate` above
# (presumably the 2-codebook representation — confirm) and play the result.
auds = ttslib.acoustic_tokenizer.decode(torch.tensor(aco_gen_toks))
Audio(auds[0], rate=24000)

Long text: testing the text-to-semantic model

In [None]:
# Longer sample sentences; punctuation is spelled out as <comma>/<period> tags.
# NOTE(review): the next cell reassigns txt1..txt4 (and adds txt5), so on a
# top-to-bottom run these values are overwritten before they are used.
txt1 = "the breeze was gentle <comma> rustling the leaves of the trees as birds chirped softly in the distance <period>"
txt2 = "it was a perfect evening to take a leisurely stroll <comma> letting the calmness of nature wash over you <period>"
txt3 = "every step on the gravel path felt like a soothing rhythm <comma> matching the tranquility of the surroundings <period>"
txt4 = "as the sky shifted from orange to deep purple <comma> the first stars began to appear <comma> twinkling like tiny diamonds in the vastness above <period>"

In [None]:
# Short sample sentences used by the cells below (txt1..txt5); punctuation is
# spelled out as <period> tags.
txt1 = "our adventure began in paris <period>"
txt2 = "the eiffel tower amazed us <period>"
txt3 = "we enjoyed cafes and croissants <period>"
txt4 = "the louvres art was stunning <period>"
txt5 = "we ended in nice by the sea <period>"

In [None]:
# Generate semantic tokens for all five sentences joined into one string, then
# print the token array's shape and the number of distinct token ids.
sem_toks = ttslib.text_to_semantic(' '.join([txt1, txt2, txt3, txt4, txt5]))
print(sem_toks.shape, np.unique(sem_toks).shape[0])

In [None]:
# Generate semantic tokens one sentence at a time and concatenate the results,
# printing each sentence's raw shape next to its shape after collapsing repeats.
sem_toks_diff = []
for sentence in (txt1, txt2, txt3, txt4, txt5):
    sentence_toks = ttslib.text_to_semantic(sentence)
    print(sentence_toks.shape, replace_consecutive(sentence_toks).shape)
    sem_toks_diff.extend(sentence_toks)

sem_toks_diff = np.array(sem_toks_diff)
sem_toks_diff.shape

In [None]:
# Overlay the token-id histograms of the single-call sequence and the
# per-sentence sequence. Labels and a legend are needed because the two
# distributions are drawn on the same axes.
fig, ax = plt.subplots()
ax.hist(sem_toks, label='all sentences in one call')
ax.hist(sem_toks_diff, alpha=0.5, label='per sentence, concatenated')
ax.set(title='Semantic token distribution', xlabel='semantic token id', ylabel='count')
ax.legend();

In [None]:
# Run semantic->audio five times on the same 150-token slice and keep every
# successful result; failures are printed and skipped.
token_window = sem_toks_diff[150:300]
auds = []
for attempt in range(5):
    try:
        aud = ttslib.semantic_to_audio(token_window)
        print(aud.shape)
        auds.append(aud)
    except Exception as err:
        print(err)

# Play back every clip that was generated.
for clip in auds:
    display(Audio(clip[0], rate=24000))

In [None]:
# Prompted text -> semantic generation over all five sentences.
# NOTE: `gen_new_prompt` is defined in a later cell of this notebook; that cell
# must be executed first, otherwise this call raises NameError on a fresh kernel.
# NOTE(review): `prompt_toks_dict` was built from semantic (source) and acoustic
# (target) prompt tokens, while this call maps TEXT -> SEMANTIC — confirm that
# the prompt modalities are what is intended here.
semantic_tokens = gen_new_prompt(
    model=ttslib.text_semantic_model,
    source_tokens=np.array(ttslib.text_tokenizer.encode(' '.join([txt1, txt2, txt3, txt4, txt5]))),
    source=TEXT,
    target=SEMANTIC,
    prompt_dict=prompt_toks_dict,
    device='cuda:0'
)

# Report the tokens this cell actually produced. The original printed
# `acoustic_tokens`, which is not defined yet at this point in the notebook.
print(semantic_tokens.shape)

Legacy gen

In [None]:
# Legacy long-form generation: try ten times on the full semantic sequence and
# keep every successful result.
auds = []

for i in range(10):
    try:
        aud = ttslib.semantic_to_audio_long(sem_toks)
        print(aud.shape)
        auds.append(aud)
    except Exception as err:
        # Catch Exception rather than using a bare `except`, so that
        # KeyboardInterrupt can still stop the loop, and report the failure
        # instead of silently dropping the attempt.
        print(f'Attempt {i} failed: {err!r}')

In [None]:
# Play back every clip collected in `auds` (first item of each result).
for aud in auds:
    display(Audio(aud[0], rate=24000))

New gen (with conditioning and prompting)

In [None]:
def gen_new(model, source, target, source_tokens, device):
    """Generate target tokens for a long source sequence with overlapping windows.

    The source is walked in windows of ``cfg.max_source_tokens // 2`` tokens
    that overlap by 64 tokens. Every window after the first is primed with the
    tail of the target generated so far, so generation continues from where the
    previous window stopped. Only the newly generated tokens of each window are
    appended to the result.

    Args:
        model: object exposing ``generate(tokens, n, temperature=..., top_k=...,
            stop_token=...)`` and returning a tensor whose first row starts with
            the input tokens.
        source: key into ``cfg.OFFSET`` for the source modality.
        target: key into ``cfg.OFFSET``, ``cfg.INFER_TOKEN`` and
            ``cfg.STOP_TOKEN`` for the target modality.
        source_tokens: 1-D numpy array of source token ids without offset.
        device: torch device the model input is placed on.

    Returns:
        1-D int64 numpy array of generated target token ids with
        ``cfg.OFFSET[target]`` removed.

    Raises:
        ValueError: if the window size does not exceed the source overlap.
    """
    source_tokens = source_tokens + cfg.OFFSET[source]
    max_source_tokens = cfg.max_source_tokens // 2

    source_overlap = 64
    target_overlap = 0
    source_stride = max_source_tokens - source_overlap
    if source_stride <= 0:
        # A non-positive stride would make the window loop silently do nothing.
        raise ValueError(
            f'max source tokens ({max_source_tokens}) must exceed the source overlap ({source_overlap})'
        )

    # Integer accumulator: a default (float64) empty array would turn every
    # token id into a float once it is stacked with the generated tokens.
    target_tokens = np.asarray([], dtype=np.int64)

    print(
        f'Source, tokens shape: {source_tokens.shape}, overlap: {source_overlap}, stride: {source_stride}, max tokens: {max_source_tokens}'
    )

    for idx in range(0, len(source_tokens), source_stride):
        end_idx = idx + max_source_tokens
        source_cut = source_tokens[idx: end_idx]
        # `arr[-0:]` selects the whole array, so a zero overlap must be handled
        # explicitly to get an empty priming context.
        target_cut = target_tokens[-target_overlap:] if target_overlap else target_tokens[:0]

        input_tokens = np.hstack([
            source_cut,
            cfg.INFER_TOKEN[target],
            target_cut
        ])

        input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=device)[None, ...]

        print(f'Source tokens shape: {input_tokens.shape}, start idx: {idx}, end idx: {end_idx}')
        print(f'Target cut shape: {target_cut.shape}, overlap: {target_overlap}')

        with torch.no_grad():
            with ctx:
                new_target_tokens = model.generate(
                    input_tokens,
                    1024,
                    temperature=0.8,
                    top_k=100,
                    stop_token=cfg.STOP_TOKEN[target]
                ).detach().cpu().numpy()[0]
                print(f'Gen shape: {new_target_tokens.shape}')

        # Drop the echoed input; keep only the newly generated tokens.
        new_target_tokens = new_target_tokens[input_tokens.shape[-1]:]

        # Update the target overlap ratio, for x toks, we generate y toks
        num_source_new_toks = end_idx-idx
        if idx:
            num_source_new_toks = end_idx-idx-source_overlap
        target_overlap = source_overlap * new_target_tokens.shape[-1]/num_source_new_toks
        target_overlap = math.ceil(target_overlap)
        # Round up to an even count so the priming context stays pair-aligned.
        target_overlap = target_overlap + 1 if target_overlap%2 != 0 else target_overlap
        print(f'Source toks: {num_source_new_toks}, New target shape: {new_target_tokens.shape}, overlap: {target_overlap}')
        # Merge into existing target tokens
        target_tokens = np.hstack([target_tokens, new_target_tokens])
        print(f'Overall target shape: {target_tokens.shape}')

        print('\n')

        # Stop once this window reached the end of the source. `>=` avoids one
        # more window over overlap tokens that were already consumed when the
        # window ends exactly on the last source token.
        if end_idx >= source_tokens.shape[-1]:
            break

    target_tokens = target_tokens - cfg.OFFSET[target]
    return target_tokens

In [None]:
def gen_new_prompt(model, source, target, source_tokens, prompt_dict, device, source_overlap=64):
    """Generate target tokens for a long source sequence, prompting the first window.

    The source is walked in windows of ``cfg.max_source_tokens // 2`` tokens
    that overlap by ``source_overlap`` tokens. The first window is built as
    ``prompt source + source window + infer token + prompt target``. Every later
    window is primed with the tail of the target generated so far. Only the
    newly generated tokens of each window are appended to the result.

    Args:
        model: object exposing ``generate(tokens, n, temperature=..., top_k=...,
            stop_token=...)`` and returning a tensor whose first row starts with
            the input tokens.
        source: key into ``cfg.OFFSET`` for the source modality.
        target: key into ``cfg.OFFSET``, ``cfg.INFER_TOKEN`` and
            ``cfg.STOP_TOKEN`` for the target modality.
        source_tokens: 1-D numpy array of source token ids without offset.
        prompt_dict: mapping with ``'source_tokens'`` and ``'target_tokens'``
            numpy arrays (without offset) used to prompt the first window.
        device: torch device the model input is placed on.
        source_overlap: number of source tokens shared by consecutive windows.

    Returns:
        1-D int64 numpy array of generated target token ids with
        ``cfg.OFFSET[target]`` removed. If a window yields an odd number of new
        tokens, that window's raw tokens are returned instead (see the note in
        the loop).

    Raises:
        ValueError: if the window size does not exceed ``source_overlap``.
        KeyError: if ``prompt_dict`` lacks one of the two expected keys.
    """
    source_tokens = source_tokens + cfg.OFFSET[source]
    max_source_tokens = cfg.max_source_tokens // 2

    # Read the prompt from the argument. The original read the notebook-global
    # `prompt_toks_dict`, which silently ignored whatever the caller passed.
    prompt_source_tokens = prompt_dict['source_tokens'] + cfg.OFFSET[source]
    prompt_target_tokens = prompt_dict['target_tokens'] + cfg.OFFSET[target]

    print(f'Prompt source tokens: {prompt_source_tokens.shape}, prompt target tokens: {prompt_target_tokens.shape}')

    target_overlap = 0
    source_stride = max_source_tokens - source_overlap
    if source_stride <= 0:
        # A non-positive stride would make the window loop silently do nothing.
        raise ValueError(
            f'max source tokens ({max_source_tokens}) must exceed the source overlap ({source_overlap})'
        )

    # Integer accumulator: a default (float64) empty array would turn every
    # token id into a float once it is stacked with the generated tokens.
    target_tokens = np.asarray([], dtype=np.int64)

    print(
        f'Source tokens shape: {source_tokens.shape}, Overlap: {source_overlap}, stride: {source_stride}, max tokens: {max_source_tokens}\n'
    )

    for idx in range(0, len(source_tokens), source_stride):
        end_idx = idx + max_source_tokens
        source_cut = source_tokens[idx: end_idx]
        # `arr[-0:]` selects the whole array, so a zero overlap must be handled
        # explicitly to get an empty priming context.
        target_cut = target_tokens[-target_overlap:] if target_overlap else target_tokens[:0]

        if idx == 0:
            # First window: condition on the prompt pair instead of prior output.
            input_tokens = np.hstack([
                prompt_source_tokens,
                source_cut,
                cfg.INFER_TOKEN[target],
                prompt_target_tokens
            ])
        else:
            input_tokens = np.hstack([
                source_cut,
                cfg.INFER_TOKEN[target],
                target_cut
            ])

        input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=device)[None, ...]

        print(f'{idx}: Target cut shape: {target_cut.shape}, overlap: {target_overlap}')
        print(f'{idx}: Source tokens shape: {input_tokens.shape}, start idx: {idx}, end idx: {end_idx}')

        with torch.no_grad():
            with ctx:
                new_target_tokens = model.generate(
                    input_tokens,
                    1024,
                    temperature=0.8,
                    top_k=100,
                    stop_token=cfg.STOP_TOKEN[target]
                ).detach().cpu().numpy()[0]
                print(f'{idx}: Total gen shape: {new_target_tokens.shape}')

        # Only take newly generated tokens
        new_target_tokens = new_target_tokens[input_tokens.shape[-1]:]

        # NOTE(review): this early exit is kept as in the original. It returns
        # only the current window's tokens, still carrying cfg.OFFSET[target],
        # and discards everything accumulated so far. It looks like a debugging
        # aid for odd-length generations — confirm whether it should remain.
        if new_target_tokens.shape[-1] % 2 != 0:
            print('breaking here')
            return new_target_tokens

        # Update the target overlap ratio, for x toks, we generate y toks
        num_source_new_toks = end_idx-idx
        if idx:
            num_source_new_toks -= source_overlap
        target_overlap = source_overlap * new_target_tokens.shape[-1]/num_source_new_toks
        target_overlap = math.ceil(target_overlap)
        # Round up to an even count so the priming context stays pair-aligned.
        target_overlap = target_overlap + 1 if target_overlap%2 != 0 else target_overlap

        print(f'{idx}: X toks: {num_source_new_toks}, Y toks: {new_target_tokens.shape}, overlap: {target_overlap}')
        # Merge into existing target tokens
        target_tokens = np.hstack([target_tokens, new_target_tokens])
        print(f'{idx}: Overall target shape is now: {target_tokens.shape}')

        print('\n')

        # Stop once this window reached the end of the source. `>=` avoids one
        # more window over overlap tokens that were already consumed when the
        # window ends exactly on the last source token.
        if end_idx >= source_tokens.shape[-1]:
            break

    target_tokens = target_tokens - cfg.OFFSET[target]
    return target_tokens

In [None]:
# Prompted semantic -> acoustic generation over the per-sentence semantic
# tokens, using the prompt pair built earlier; prints the resulting shape.
acoustic_tokens = gen_new_prompt(
    model=ttslib.semantic_acoustic_model, 
    source_tokens=sem_toks_diff,
    source=SEMANTIC,
    target=ACOUSTIC,
    prompt_dict=prompt_toks_dict,
    device='cuda:0'
)

print(acoustic_tokens.shape)

In [None]:
# Decode the generated acoustic tokens to a waveform and play the first item.
wav = ttslib.acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
Audio(wav[0], rate=24000)

In [None]:
# Same decode as the previous cell. Alternative kept for reference: decode
# `acoustic_tokens[:-1]` instead, presumably to drop a trailing unpaired token
# when generation stops on an odd count (TODO confirm).
wav = ttslib.acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
Audio(wav[0], rate=24000)

In [None]:
# Unprompted semantic -> acoustic generation over the same semantic tokens,
# for comparison with the prompted run; overwrites `acoustic_tokens`.
acoustic_tokens = gen_new(
    model=ttslib.semantic_acoustic_model, 
    source_tokens=sem_toks_diff,
    source=SEMANTIC,
    target=ACOUSTIC,
    device='cuda:0'
)

print(acoustic_tokens.shape)

In [None]:
# Decode the unprompted acoustic tokens to a waveform and play the first item.
wav = ttslib.acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
Audio(wav[0], rate=24000)