In [None]:
import os
import torch
import math
import numpy as np
from pathlib import Path
from IPython.display import Audio
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download

from audiotoken import AudioToken, Tokenizers

from tts.infer import AudioSemantic as VanillaAudioSemantic, load_model, generate
from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, TEXT, ctx, cache_dir

In [None]:
# Location of the tts_en_xl_125m model files inside `cache_dir`.
# Note: the trailing '/' here plus the leading '/' below yields a doubled
# slash in the checkpoint path; that is kept exactly as written.
model_dir = f'{cache_dir}/models/tts_en_xl_125m/'

# Checkpoint handed to `load_model` for the semantic_text model.
semantic_text_ckpt = f'{model_dir}/semantic_text/gpt_last.pt'
semantic_text_model = load_model(path=semantic_text_ckpt)

# Helper object from `tts.infer`; its `text_tokenizer` attribute is used
# further down to encode the prompt text.
ttslib = VanillaAudioSemantic()

In [None]:
def replace_consecutive(arr):
    """Collapse each run of equal adjacent values in a 1-D sequence to one value.

    Example: ``[1, 1, 2, 2, 2, 3, 1, 1]`` becomes ``[1, 2, 3, 1]``. Only
    *adjacent* duplicates are removed; a value may reappear later.

    Args:
        arr: 1-D array-like supporting ``len``, slicing, element-wise ``!=``
            and boolean-mask indexing (a NumPy array; the notebook also
            passes an object exposing ``.numpy()``, presumably a CPU torch
            tensor — TODO confirm). Plain Python lists are not supported.

    Returns:
        ``arr`` indexed by the keep-mask, so the result has the same container
        type as the input. An empty input is returned unchanged.
    """
    # An empty input has no runs. Without this guard the mask below would have
    # length 1 and indexing a length-0 array with it raises IndexError.
    if len(arr) == 0:
        return arr

    # Always keep the first element, then every element that differs from the
    # one immediately before it.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]

In [None]:
# Semantic audio tokenizer from `audiotoken`. Use the first CUDA device when
# one is available, otherwise fall back to CPU so the notebook does not fail
# at this cell on a machine without a GPU.
semantic_tokenizer = AudioToken(
    Tokenizers.semantic_s,
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

In [None]:
# Encode the reference prompt audio into semantic tokens. The `[0][0]` index
# picks the first entry along the two leading axes of the encoder output
# (presumably batch and codebook — TODO confirm), after which runs of
# repeated tokens are collapsed.
encoded_prompt = semantic_tokenizer.encode(Path('prompts/jenny_prompt.wav'))
prompt_sem_toks = replace_consecutive(encoded_prompt[0][0])

# Text token ids for the prompt's transcript.
prompt_text_ids = ttslib.text_tokenizer.encode('said meg impatiently <period>')

# Both sides of the prompt as int64 NumPy arrays.
ts_prompt_toks_dict = {
    'source_tokens': prompt_sem_toks.numpy().astype(np.int64),
    'target_tokens': np.array(prompt_text_ids).astype(np.int64),
}

# Sanity check: show the length of each token sequence.
print(*(ts_prompt_toks_dict[key].shape for key in ('source_tokens', 'target_tokens')))

In [None]:
# Feed the prompt's semantic tokens to the semantic_text model and ask it to
# produce tokens in the TEXT modality.
semantic_prompt_tokens = ts_prompt_toks_dict['source_tokens']
txt_toks = generate(
    model=semantic_text_model,
    source=SEMANTIC,
    target=TEXT,
    source_tokens=semantic_prompt_tokens,
)

In [None]:
from datalib.tokenlib import get_tokenizer

In [None]:
# Tokenizer for the TEXT modality, used below to decode generated token ids.
# The second argument 'cpu' is presumably the device — TODO confirm against
# `datalib.tokenlib.get_tokenizer`.
decoder = get_tokenizer(TEXT, 'cpu')

In [None]:
# Decode the generated text tokens; as the last expression in the cell, the
# result is shown as the cell output.
decoder.decode(txt_toks)

In [None]:
from datasets import load_dataset
# Load a LibriSpeech validation split. "Valid" is not a split name for this
# dataset and makes `load_dataset` raise; with no config name given, the
# validation splits are "validation.clean" and "validation.other".
# NOTE(review): "validation.clean" is assumed to be the intended split —
# switch to "validation.other" if the noisier set was meant.
librispeech = load_dataset(
    "openslr/librispeech_asr",
    trust_remote_code=True,
    split="validation.clean",
)