In [None]:
import os
import torch
import math
import numpy as np
from pathlib import Path
from IPython.display import Audio
import matplotlib.pyplot as plt
from huggingface_hub import snapshot_download

from audiotoken import AudioToken, Tokenizers

from tts.infer import AudioSemantic as VanillaAudioSemantic, generate
from tts.train import DataLoader
from omni.hfload import convert_to_hf
from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, TEXT, ctx, cache_dir, device

In [None]:
# Root of the downloaded TTS checkpoints for this run.
model_dir = f'{cache_dir}/models/tts_en_xl_125m/'
# Semantic -> text checkpoint, converted into a model object on the shared device.
semantic_text_ckpt = f'{model_dir}/semantic_text/gpt_last.pt'
semantic_text_model = convert_to_hf(path=semantic_text_ckpt, device=device)
ttslib = VanillaAudioSemantic()

In [None]:
def replace_consecutive(arr):
    """Collapse every run of consecutive equal elements into a single element.

    An element is kept only if it differs from the element right before it
    (the first element is always kept), e.g. ``[1, 1, 2, 2, 2, 3, 1]`` ->
    ``[1, 2, 3, 1]``.

    Args:
        arr: 1-D array (anything supporting slicing, element-wise ``!=`` and
            boolean-mask indexing, such as a NumPy array).

    Returns:
        The run heads selected by boolean-mask indexing. An empty input is
        returned as-is.
    """
    # Without this guard an empty input builds a length-1 mask for a
    # length-0 array and the boolean index raises IndexError.
    if len(arr) == 0:
        return arr
    # mask[i] is True when arr[i] starts a new run.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]

In [None]:
# Pre-computed tokens for the prompt. `np.load` on an .npz returns an NpzFile
# that holds the file open; the `with` block closes it once the SEMANTIC array
# has been read into memory.
with np.load('prompts/jenny_short/tokens.npz') as toks:
    semantic_tokens = toks['SEMANTIC']

In [None]:
# Shape of the pre-computed semantic tokens loaded from tokens.npz.
np.shape(semantic_tokens)

In [None]:
# Use the shared `device` (imported from `common`, also used for the model
# above) instead of a hard-coded 'cuda:0'. This keeps the tokenizer and model
# on the same device and lets the notebook run where that GPU is unavailable.
semantic_tokenizer = AudioToken(Tokenizers.semantic_s, device=device)

In [None]:
from functools import partial
from transformers import Wav2Vec2FeatureExtractor

def hubert_processor(audio, processor):
    """Run a HuBERT-style feature extractor on one waveform.

    Args:
        audio: waveform to featurise. It is declared to the extractor as
            16 kHz; presumably it must already be at that rate (TODO confirm
            with the loader).
        processor: feature-extractor callable (e.g. a
            ``Wav2Vec2FeatureExtractor``) that accepts ``sampling_rate`` and
            ``return_tensors`` keyword arguments.

    Returns:
        The first entry of the extractor's ``input_values``, requested as
        PyTorch tensors via ``return_tensors='pt'``.
    """
    features = processor(audio, sampling_rate=16_000, return_tensors='pt')
    batch_values = features.input_values
    return batch_values[0]


# Pretrained checkpoint whose feature-extractor config is used for the prompt audio.
HUBERT_CHECKPOINT = 'voidful/mhubert-base'
processor = Wav2Vec2FeatureExtractor.from_pretrained(HUBERT_CHECKPOINT)
# Bind the extractor so callers only pass the waveform: transform_func(audio).
transform_func = partial(hubert_processor, processor=processor)

In [None]:
from tts.utils import read_audio_file

In [None]:
# Load the prompt waveform at 16 kHz and featurise it in the same cell. The raw
# waveform keeps its own name, so re-running this cell always starts from the
# file and never feeds extractor output back into the extractor.
aud_raw = read_audio_file('prompts/jenny_short/audio.wav', 16000)
aud = transform_func(aud_raw)

In [None]:
# Shape of the featurised prompt audio (np.shape returns the object's own .shape).
np.shape(aud)

In [None]:
# NOTE(review): this re-defines `replace_consecutive` from an earlier cell.
# Keep the two copies identical, or delete this one.
def replace_consecutive(arr):
    """Collapse every run of consecutive equal elements into a single element.

    An element is kept only if it differs from the element right before it
    (the first element is always kept), e.g. ``[1, 1, 2, 2, 2, 3, 1]`` ->
    ``[1, 2, 3, 1]``.

    Args:
        arr: 1-D array (anything supporting slicing, element-wise ``!=`` and
            boolean-mask indexing, such as a NumPy array).

    Returns:
        The run heads selected by boolean-mask indexing. An empty input is
        returned as-is.
    """
    # Without this guard an empty input builds a length-1 mask for a
    # length-0 array and the boolean index raises IndexError.
    if len(arr) == 0:
        return arr
    # mask[i] is True when arr[i] starts a new run.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]

In [None]:
# Re-encode the prompt audio into semantic ids, move them to NumPy, pick
# element [0, 0] of the two leading axes (presumably batch / codebook, TODO
# confirm), then collapse repeated consecutive ids.
encoded_semantic = semantic_tokenizer.encode(aud)
source_tokens = replace_consecutive(encoded_semantic.cpu().numpy()[0, 0])

In [None]:
# Length of the de-duplicated semantic ids re-encoded from audio.wav.
np.shape(source_tokens)

In [None]:
# NOTE(review): generation is conditioned on the pre-computed `semantic_tokens`
# from tokens.npz, not on the `source_tokens` re-encoded from audio.wav above.
# Confirm which one is intended.
#
# temperature / top_k indicate sampling, so seed torch's RNG to make the
# generated text reproducible across runs. This assumes `generate` draws its
# randomness from torch.
GENERATION_SEED = 42
torch.manual_seed(GENERATION_SEED)
txt_toks = generate(
    model=semantic_text_model,
    source_tokens=semantic_tokens,
    source=SEMANTIC,
    target=TEXT,
    max_length=1024,
    max_source_tokens=768,
    temperature=0.8,
    top_k=100
)

In [None]:
# Raw ids returned by `generate`. They still carry the TEXT offset
# (`cfg.OFFSET[TEXT]`), which has to be subtracted before text decoding.
txt_toks

In [None]:
from datalib.tokenlib import get_tokenizer

In [None]:
# Text tokenizer, used in this notebook only to decode generated ids; it is created on the CPU.
text_tokenizer_device = 'cpu'
decoder = get_tokenizer(TEXT, text_tokenizer_device)

In [None]:
# Remove the TEXT offset carried by the generated ids, then decode to a string.
text_ids = txt_toks - cfg.OFFSET[TEXT]
decoder.decode(text_ids)