In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq, TrainerCallback
# Scratch WAV file: each playback cell overwrites it and then plays it back.
AUDIO_FILE_PATH = "data/reconstructed_audio.wav"
from utils import reconstruct_codec, flat_codec

# Hub id of the base model; the tokenizer is loaded from it.
BASE_MODEL = "BEE-spoke-data/smol_llama-101M-GQA"
# Local checkpoint directory; the causal-LM weights are loaded from it.
model_path = "./results/2024-05-24_22:15-fixed_seperator_test_with_less_databased_on_BEE-spoke-data/smol_llama-101M-GQA/checkpoint-30000"
# Passed as `max_length` to `model.generate` in the generation cell.
MAX_LEN = 512

# --- tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    # No dedicated pad token: register the EOS token as the pad token.
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
# Alternative (disabled): tokenizer = AutoTokenizer.from_pretrained(model_path)

# --- model ---
# Load the checkpoint weights and move the model to the GPU in one step.
model = AutoModelForCausalLM.from_pretrained(model_path).cuda()

In [None]:
import soundfile as sf
from snac import SNAC
# Load the pretrained SNAC codec, then switch it to eval mode and move it to the GPU.
snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac = snac.eval().cuda()

In [None]:
from datasets import load_dataset
from IPython.display import Audio

# Reference playback: decode the ground-truth SNAC tokens of one held-out sample
# so the generated audio can be compared against it by ear.

# Only the first 100 training rows are loaded -- a small subset for testing.
dataset = load_dataset("blanchon/snac_llm_parler_tts", split="train[0:100]")
# Fixed seed so the same test sample is picked on every run.
dataset = dataset.train_test_split(test_size=0.3, seed=42)

sample = dataset["test"][0]
# `split()` (no argument) tolerates repeated or trailing whitespace, which
# `split(" ")` would turn into empty strings that `int()` rejects.
tokens = [int(t) for t in sample["snac24khz"].split()]
codes = reconstruct_codec(tokens)

# Decoding is inference only: skip autograd bookkeeping instead of building a
# graph that is immediately thrown away by `.detach()`.
with torch.no_grad():
    audio_hat = snac.decode(codes)

# 24000 Hz sample rate -- presumably matching the "snac_24khz" checkpoint; confirm if the codec changes.
sf.write(AUDIO_FILE_PATH, audio_hat.cpu().detach().numpy().squeeze(), 24000)
print(sample["text"])
Audio(AUDIO_FILE_PATH)

In [None]:
# Text-to-audio generation for `sample`: the prompt is the tokenized text followed
# by the tokenized "[audio]" separator; everything the model generates after the
# prompt is treated as audio codec tokens.

# Token ids of the separator without the first id
# (presumably the BOS token the tokenizer prepends -- confirm for this tokenizer).
SEP = tokenizer("[audio]")["input_ids"][1:]

# Truncate the text so that text + separator fits in MAX_LEN (same budget as
# `max_length` below, instead of a second hard-coded 512).
input_ids = tokenizer(sample["text"], padding=False, truncation=True, max_length=MAX_LEN - len(SEP))["input_ids"] + SEP
#input_ids += [int(t) for t in sample["snac24khz"].split(" ")][:7] # try to enforce right voice

prompt = torch.tensor([input_ids]).to("cuda")

with torch.no_grad():
    outputs = model.generate(
        prompt,
        # Single unpadded sequence: every position is a real token. Passed
        # explicitly because pad == eos, so generate() cannot infer the mask.
        attention_mask=torch.ones_like(prompt),
        max_length=MAX_LEN,
        pad_token_id=tokenizer.eos_token_id,
        # Greedy decoding. `temperature` only applies when sampling, so
        # `temperature=0` did not express this; `do_sample=False` does.
        do_sample=False,
    )
    # Keep only the newly generated ids (drop the echoed prompt).
    outputs = outputs[0][len(input_ids):]

codes = reconstruct_codec(outputs)
# Inference only: no autograd graph needed for the codec decode.
with torch.no_grad():
    audio_hat = snac.decode(codes)
sf.write(AUDIO_FILE_PATH, audio_hat.cpu().detach().numpy().squeeze(), 24000)
print(sample["text"])
Audio(AUDIO_FILE_PATH)