In [None]:
import sys
sys.path.append('..')

import os
import math
import torch
import random
import numpy as np
from pathlib import Path
from torch.cuda import empty_cache
from IPython.display import Audio
import matplotlib.pyplot as plt

from configs.commons import Config as cfg
from configs.commons import DEVICE, CACHE_DIR, CTX
from configs.constants import *

from omni.hfload import convert_to_hf
from datalib.tokenlib import get_tokenizer
from omni.train_omni_instruct import DataLoader

In [None]:
# Rebinds DEVICE for the rest of the notebook, replacing the value that the
# imports cell pulled in from configs.commons.
DEVICE = 'cuda:0'

In [None]:
# Load two checkpoints through convert_to_hf onto DEVICE. Later cells call
# omni_model.generate(...) on `input_tokens` and
# semantic_acoustic_model.generate(...) on the semantic prompt.
# NOTE(review): both checkpoint paths are absolute and machine-specific.
omni_model = convert_to_hf(path=f'/home/.cache/indri/models/omni_tasks_large_full_sprk/gpt_13000.pt', device=DEVICE)
semantic_acoustic_model = convert_to_hf(path=f'/home/.cache/indri/models/semantic_acoustic_tasks_spkr/gpt_26500.pt', device=DEVICE)

# Tokenizers for the TEXT and ACOUSTIC modalities. `text_tokenizer` is rebound
# to dl.text_tokenizer in the DataLoader cell, so the instance created here is
# replaced before its first use in this notebook.
text_tokenizer = get_tokenizer(TEXT, device='cpu')
acoustic_tokenizer = get_tokenizer(ACOUSTIC, device=DEVICE)

In [None]:
# Instantiate the DataLoader from omni.train_omni_instruct with empty
# interleaved/dataset directory lists; only the allowed-speakers file is given.
# In this notebook `dl` is used as the source of the special tokens
# (stop / convert / modality tokens) and of the text tokenizer.
dl = DataLoader(
    interleaved_dirs=[],
    datasets_dirs=[],
    speaker_files=[Path('../allowed_speakers.jsonl').resolve()]
)

# From here on, use the DataLoader's own tokenizer instance.
text_tokenizer = dl.text_tokenizer

In [None]:
# Set the end-of-sequence id in both models' generation configs to the
# DataLoader's stop token, so generate() uses the same token the trimming
# cells below search for.
omni_model.generation_config.eos_token_id = dl.stop_token
semantic_acoustic_model.generation_config.eos_token_id = dl.stop_token

In [None]:
# Encoded ids for the "[spkr_unk]" speaker tag and for the acoustic modality
# marker string taken from cfg.MODALITY_TOKENS. Both are spliced into the
# prompts built in later cells.
speaker_id = text_tokenizer.encode("[spkr_unk]")
acoustic_modality_token = text_tokenizer.encode(cfg.MODALITY_TOKENS[ACOUSTIC])

In [None]:
# Sample sentence to synthesize, encoded with the text tokenizer and wrapped
# in a numpy array so it can be concatenated into a prompt.
random_txt = "once upon a time there was a girl named emily"
txt_toks = np.array(text_tokenizer.encode(random_txt))

In [None]:
# Text -> semantic prompt. np.hstack flattens the pieces into one sequence, in
# this order: text modality token, text ids, convert token, semantic modality
# token, speaker tag.
input_tokens = np.hstack([
    dl.text_modality_token,
    txt_toks,
    dl.convert_token,
    dl.semantic_modality_token,
    speaker_id,
])
# Long tensor on DEVICE; `[None, ...]` adds a leading axis of size 1.
input_tokens = (torch.tensor(input_tokens, dtype=torch.long, device=DEVICE)[None, ...])
print(f'Text tokens: {input_tokens.shape}')
# Decode the assembled prompt back to text as a sanity check.
text_tokenizer.decode(input_tokens[0])

In [None]:
# Alternative (semantic -> text) prompt: semantic modality token, speaker tag,
# temp_sem_toks, convert token, text modality token.
# NOTE(review): this rebinds `input_tokens`, replacing the text -> semantic
# prompt built in the previous cell; the generate cell below then runs on
# whichever of the two cells was executed last.
# NOTE(review): `temp_sem_toks` is not assigned anywhere above this cell in the
# visible notebook; its assignments are in the "Testing with custom tokens"
# cell further down, so this cell raises NameError on a fresh top-to-bottom
# run. After that later cell completes, temp_sem_toks is a batched tensor on
# DEVICE that already contains modality/speaker/convert tokens — confirm which
# form is intended here.
input_tokens = np.hstack([
    dl.semantic_modality_token,
    speaker_id,
    temp_sem_toks,
    dl.convert_token,
    dl.text_modality_token,
])
# Long tensor on DEVICE; `[None, ...]` adds a leading axis of size 1.
input_tokens = (torch.tensor(input_tokens, dtype=torch.long, device=DEVICE)[None, ...])
# NOTE(review): the label below says "Text tokens" although this prompt starts
# with the semantic modality token.
print(f'Text tokens: {input_tokens.shape}')
text_tokenizer.decode(input_tokens[0])

In [None]:
# Stage 1: sample a continuation of `input_tokens` with the omni model.
with CTX:
    semantic_tokens = omni_model.generate(
        input_tokens, do_sample=True, top_k=100, temperature=0.8, max_length=1024,
    )
    # Keep batch row 0 and skip the first input_tokens.shape[-1] positions
    # (the prompt length), then copy what remains to a host numpy array.
    semantic_tokens = semantic_tokens[0, input_tokens.shape[-1]:]
    semantic_tokens = semantic_tokens.detach().cpu().numpy()
    print(semantic_tokens.shape)

# Decode the generated ids for inspection.
text_tokenizer.decode(semantic_tokens)

In [None]:
# Trim the generated sequence at the first stop token.
# generate() can reach max_length without emitting the stop token, and a
# re-run of this cell sees an already-trimmed array; in both cases there is no
# match, so indexing the first match unconditionally would raise IndexError.
# When no stop token is present the whole sequence is kept.
stop_positions = np.where(semantic_tokens == dl.stop_token)[0]
end_idx = stop_positions[0] if stop_positions.size > 0 else len(semantic_tokens)
semantic_tokens = semantic_tokens[0:end_idx]
print(semantic_tokens.shape)
# Decode the trimmed ids for inspection.
text_tokenizer.decode(semantic_tokens)

In [None]:
# Semantic -> acoustic prompt: semantic modality token, speaker tag, the
# trimmed semantic ids, convert token, acoustic modality token, speaker tag.
# NOTE(review): `semantic_tokens` is rebound here from the trimmed numpy array
# to the full prompt, so re-running this cell feeds the already-wrapped prompt
# back into np.hstack instead of the raw semantic ids.
semantic_tokens = np.hstack([
    dl.semantic_modality_token,
    speaker_id,
    semantic_tokens,
    dl.convert_token,
    acoustic_modality_token,
    speaker_id,
])
# Long tensor on DEVICE; `[None, ...]` adds a leading axis of size 1.
semantic_tokens = (torch.tensor(semantic_tokens, dtype=torch.long, device=DEVICE)[None, ...])
print(f'Semantic tokens: {semantic_tokens.shape}')
# Decode the assembled prompt as a sanity check.
dl.text_tokenizer.decode(semantic_tokens[0])

In [None]:
# Stage 2: sample a continuation of the semantic prompt with the
# semantic -> acoustic model.
with CTX:
    acoustic_tokens = semantic_acoustic_model.generate(
        semantic_tokens, do_sample=True, top_k=100, temperature=0.8, max_length=3072,
    )

    # Keep batch row 0 and skip the first semantic_tokens.shape[-1] positions
    # (the prompt length), then copy what remains to a host numpy array.
    acoustic_tokens = acoustic_tokens[0, semantic_tokens.shape[-1]:]
    acoustic_tokens = acoustic_tokens.detach().cpu().numpy()
    print(acoustic_tokens.shape)

# Decode the generated ids for inspection.
dl.text_tokenizer.decode(acoustic_tokens)

In [None]:
# Trim the generated sequence at the first stop token.
# generate() can reach max_length without emitting the stop token; indexing
# the first match unconditionally would then raise IndexError. When no stop
# token is present the whole sequence is kept.
stop_positions = np.where(acoustic_tokens == dl.stop_token)[0]
end_idx = stop_positions[0] if stop_positions.size > 0 else len(acoustic_tokens)
acoustic_tokens = acoustic_tokens[0:end_idx]
# Shift ids back by the acoustic offset before handing them to the tokenizer.
# NOTE(review): this rebinds the same name, so re-running the cell subtracts
# the offset a second time.
acoustic_tokens = acoustic_tokens - cfg.OFFSET[ACOUSTIC]

# Drop the last id when the count is odd — presumably because the acoustic
# tokenizer decodes ids in pairs; TODO confirm.
if len(acoustic_tokens) % 2 == 1:
    acoustic_tokens = acoustic_tokens[:-1]

print(f'Acoustic tokens: {acoustic_tokens.shape}')

In [None]:
# Decode the acoustic ids to a waveform with the acoustic tokenizer and play
# it inline.
wav = acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
# Take the first item of the decoder output and move it to a host numpy array.
wav = wav[0].cpu().numpy()
# NOTE(review): 24000 is presumably the tokenizer's output sample rate — confirm.
Audio(wav, rate=24000)

In [None]:
import torchaudio

In [None]:
# Write the decoded waveform to a WAV file (path is relative to the current
# working directory) at a 24000 Hz sample rate.
# NOTE(review): torchaudio.save documents a 2-D (channels, frames) tensor —
# confirm `wav` has that shape.
torchaudio.save('jenny_7k_sem_aco.wav', torch.from_numpy(wav), sample_rate=24000)

In [None]:
## Testing with custom tokens
# Builds a semantic -> acoustic prompt from a stored prompt archive instead of
# from semantic ids produced by omni_model.
speaker_id = text_tokenizer.encode("[spkr_unk]")

# Read the 'SEMANTIC' array from the saved prompt archive as int64, then shift
# it in place by cfg.OFFSET[SEMANTIC].
prompt = np.load('../prompts/lj_female_long/tokens.npz')
temp_sem_toks = prompt['SEMANTIC'].astype(np.int64)
# temp_sem_toks = prompt
temp_sem_toks += cfg.OFFSET[SEMANTIC]
# Prompt layout: semantic modality token, speaker tag, semantic ids, convert
# token, acoustic modality token, speaker tag.
temp_sem_toks = np.hstack([
    dl.semantic_modality_token,
    speaker_id,
    temp_sem_toks,
    dl.convert_token,
    acoustic_modality_token,
    speaker_id,
])
# Long tensor on DEVICE; `[None, ...]` adds a leading axis of size 1.
temp_sem_toks = (torch.tensor(temp_sem_toks, dtype=torch.long, device=DEVICE)[None, ...])
print(temp_sem_toks.shape)

# Decode the assembled prompt as a sanity check.
text_tokenizer.decode(temp_sem_toks[0])

In [None]:
# Load one stored text-token array from the local dataset cache.
# NOTE(review): the path is absolute and machine-specific.
txt  = np.load('/home/.cache/indri/data/gs_xl_en_tokens/tokens/text/YOU0000013586_S0000067.npy')

In [None]:
# Decode the loaded ids back to a string for inspection.
text_tokenizer.decode(txt)

In [None]:
def replace_consecutive(arr):
    """Collapse every run of equal adjacent values into a single element.

    Only neighbouring duplicates are removed; a value that reappears after a
    different value is kept (``[1, 1, 2, 1] -> [1, 2, 1]``).

    Args:
        arr: 1-D numpy array, or any sequence ``np.asarray`` can convert.

    Returns:
        A new 1-D numpy array containing the first element of each run, in the
        original order. An empty input yields an empty array.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # The mask below always starts with one True, so it would have length 1
        # and could not index a zero-length array.
        return arr.copy()
    # Keep the first element, then every element that differs from its
    # predecessor.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]