In [None]:
import sys
sys.path.append('..')

import os
import math
import torch
import random
import numpy as np
from pathlib import Path
from torch.cuda import empty_cache
from IPython.display import Audio
import matplotlib.pyplot as plt

from configs.commons import Config as cfg
from configs.commons import DEVICE, CACHE_DIR, CTX
from configs.constants import *

from omni.hfload import convert_to_hf
from datalib.tokenlib import get_tokenizer

In [None]:
# Rebinds the DEVICE name imported from configs.commons; the models, the
# acoustic tokenizer and the prompt tensors in later cells are placed on it.
DEVICE = 'cuda:0'

In [None]:
# Load both checkpoints onto DEVICE via convert_to_hf. Later cells call
# `.generate` and set `.generation_config` on the returned objects.
omni_model = convert_to_hf(
    path='/home/.cache/indri/models/omni_tasks_large/gpt_2900.pt',
    device=DEVICE,
)
semantic_acoustic_model = convert_to_hf(
    path='/home/.cache/indri/romit/models/semantic_acoustic_tasks_small/gpt_4500.pt',
    device=DEVICE,
)

# The text tokenizer is created on CPU, the acoustic tokenizer on DEVICE.
text_tokenizer = get_tokenizer(
    TEXT,
    device='cpu',
)
acoustic_tokenizer = get_tokenizer(
    ACOUSTIC,
    device=DEVICE,
)

In [None]:
# Extend the text tokenizer's vocabulary with one placeholder token per
# semantic code ('[sem_0]', '[sem_1]', ...) and per acoustic code ('[aco_0]', ...).
# Each group is registered with a single batched add_tokens call instead of one
# call per token; the tokens are passed in the same order as before.
# NOTE(review): assumes `text_tokenizer.tokenizer.add_tokens` accepts a list of
# strings (as Hugging Face tokenizers do) — confirm against datalib.tokenlib.
text_tokenizer.tokenizer.add_tokens(
    [f'[sem_{idx}]' for idx in range(cfg.VOCAB_SIZES[SEMANTIC])]
)
text_tokenizer.tokenizer.add_tokens(
    [f'[aco_{idx}]' for idx in range(cfg.VOCAB_SIZES[ACOUSTIC])]
)

# Modality markers, task markers and the stop token are registered last,
# in that order.
special_tokens = [
    *cfg.MODALITY_TOKENS.values(),
    *cfg.TASK_TOKENS.values(),
    cfg.STOP_TOKEN,
]
for tok in special_tokens:
    print('Adding token: ', tok)
text_tokenizer.tokenizer.add_tokens(special_tokens)

In [None]:
# Encode the control tokens that the prompt-building cells stitch together
# with np.hstack.
convert_token = text_tokenizer.encode(cfg.TASK_TOKENS[CONVERT])
continue_token = text_tokenizer.encode(cfg.TASK_TOKENS[CONTINUE])
stop_token = text_tokenizer.encode(cfg.STOP_TOKEN)
semantic_modality_token = text_tokenizer.encode(cfg.MODALITY_TOKENS[SEMANTIC])
acoustic_modality_token = text_tokenizer.encode(cfg.MODALITY_TOKENS[ACOUSTIC])
text_modality_token = text_tokenizer.encode(cfg.MODALITY_TOKENS[TEXT])
# NOTE(review): "[spkr_unk]" is not added to the vocabulary by any cell visible
# in this notebook; presumably it already exists in the tokenizer — verify.
speaker_id = text_tokenizer.encode("[spkr_unk]")

# Use the encoded stop token as the end-of-sequence id for both models.
omni_model.generation_config.eos_token_id = stop_token
semantic_acoustic_model.generation_config.eos_token_id = stop_token

In [None]:
# Sample sentence used as the text prompt, encoded with the text tokenizer
# and wrapped in a NumPy array.
random_txt = "once upon a time there was a girl named emily"
txt_toks = np.array(text_tokenizer.encode(random_txt))

In [None]:
# Prompt layout for text -> semantic conversion:
#   [text modality] <text tokens> [convert] [semantic modality] [speaker]
input_tokens = torch.tensor(
    np.hstack([
        text_modality_token,
        txt_toks,
        convert_token,
        semantic_modality_token,
        speaker_id,
    ]),
    dtype=torch.long,
    device=DEVICE,
).unsqueeze(0)  # add a leading dimension of size 1
print(f'Text tokens: {input_tokens.shape}')

In [None]:
# Alternative prompt for the reverse direction (semantic -> text):
#   [semantic modality] [speaker] <semantic tokens> [convert] [text modality]
# NOTE(review): within the visible notebook `temp_sem_toks` is only assigned in
# the later "Testing with custom tokens" cell, so this cell raises NameError on
# a fresh top-to-bottom run. It also rebinds `input_tokens`, replacing the
# text -> semantic prompt built by the previous cell.
input_tokens = np.hstack([
    semantic_modality_token,
    speaker_id,
    temp_sem_toks,
    convert_token,
    text_modality_token,
])
# Convert to a long tensor on DEVICE with a leading dimension of size 1.
input_tokens = (torch.tensor(input_tokens, dtype=torch.long, device=DEVICE)[None, ...])
# NOTE(review): the label says 'Text tokens' although this prompt is built
# around semantic tokens.
print(f'Text tokens: {input_tokens.shape}')

In [None]:
# Display the first (only) row of the assembled prompt decoded by the text tokenizer.
text_tokenizer.decode(input_tokens[0])

In [None]:
# Sample from the omni model inside the CTX context manager imported from
# configs.commons.
with CTX:
    semantic_tokens = omni_model.generate(
        input_tokens,
        do_sample=True,
        top_k=100,
        temperature=0.7,
        max_length=1024,
    )
    # Take row 0, drop the leading prompt positions, and move the remainder to
    # a NumPy array on the CPU.
    semantic_tokens = semantic_tokens[0, input_tokens.shape[-1]:].detach().cpu().numpy()
    print(semantic_tokens.shape)

In [None]:
# Display the generated token ids (prompt already stripped).
semantic_tokens

In [None]:
# Display the generated ids decoded through the text tokenizer.
text_tokenizer.decode(semantic_tokens)

In [None]:
# Cut the generated sequence at the first stop token (the stop token itself is
# dropped). If generation ended without emitting a stop token, keep the whole
# sequence instead of failing with IndexError on an empty match list.
stop_positions = np.where(semantic_tokens == stop_token)[0]
end_idx = stop_positions[0] if stop_positions.size else len(semantic_tokens)
semantic_tokens = semantic_tokens[:end_idx]

In [None]:
# Prompt layout for semantic -> acoustic conversion:
#   [semantic modality] [speaker] <semantic tokens> [convert] [acoustic modality] [speaker]
# `semantic_tokens` is rebound here: it goes in as the trimmed generated ids
# and comes out as the full prompt tensor with a leading dimension of size 1.
semantic_tokens = torch.tensor(
    np.hstack([
        semantic_modality_token,
        speaker_id,
        semantic_tokens,
        convert_token,
        acoustic_modality_token,
        speaker_id,
    ]),
    dtype=torch.long,
    device=DEVICE,
).unsqueeze(0)
print(f'Semantic tokens: {semantic_tokens.shape}')

In [None]:
# Sample from the semantic -> acoustic model inside the CTX context manager.
with CTX:
    acoustic_tokens = semantic_acoustic_model.generate(
        semantic_tokens,
        max_length=3072,
        temperature=0.9,
        top_k=100,
        do_sample=True
    )

    # Move row 0 to a NumPy array and drop the leading prompt positions.
    acoustic_tokens = acoustic_tokens.detach().cpu().numpy()[0]
    acoustic_tokens = acoustic_tokens[semantic_tokens.shape[-1]:]
    print(acoustic_tokens.shape)

# Cut at the first stop token (exclusive). If no stop token was generated,
# keep the whole sequence instead of failing with IndexError.
stop_positions = np.where(acoustic_tokens == stop_token)[0]
end_idx = stop_positions[0] if stop_positions.size else len(acoustic_tokens)
acoustic_tokens = acoustic_tokens[:end_idx]
# Remove the acoustic offset from the ids before handing them to the acoustic tokenizer.
acoustic_tokens = acoustic_tokens - cfg.OFFSET[ACOUSTIC]

# Drop a trailing token when the count is odd.
# NOTE(review): presumably the acoustic codes come in pairs — confirm.
if len(acoustic_tokens) % 2 == 1:
    acoustic_tokens = acoustic_tokens[:-1]

print(f'Acoustic tokens: {acoustic_tokens.shape}')

In [None]:
# Decode the acoustic ids with the acoustic tokenizer, take the first item of
# the result as a CPU NumPy array, and render it as an audio player at 24 kHz.
wav = acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))[0].cpu().numpy()
Audio(data=wav, rate=24000)

In [None]:
# Rebind speaker_id to the encoding of "[spkr_jenny]", replacing the
# "[spkr_unk]" value assigned earlier in the notebook.
speaker_id = text_tokenizer.encode("[spkr_jenny]")

In [None]:
## Testing with custom tokens
# Load the stored tokens for the prompt and shift the 'SEMANTIC' ids by the
# semantic offset.
prompt = np.load('../prompts/jenny_short/tokens.npz')
# Widen to int64 before adding the offset: an in-place `+=` keeps the stored
# dtype, which can overflow (or raise) when the file holds a narrow integer
# type. The stored dtype is not visible from this notebook — TODO confirm.
temp_sem_toks = prompt['SEMANTIC'].astype(np.int64) + cfg.OFFSET[SEMANTIC]
print(temp_sem_toks.shape)

In [None]:
# Display the decoding of the first element of temp_sem_toks.
# NOTE(review): if temp_sem_toks is 1-D, `[0]` passes a single id rather than
# a sequence to decode — confirm that this is intended.
text_tokenizer.decode(temp_sem_toks[0])

In [None]:
def replace_consecutive(arr):
    """Collapse runs of equal adjacent values in a 1-D sequence.

    Args:
        arr: 1-D NumPy array, or any sequence accepted by ``np.asarray``.

    Returns:
        A new NumPy array in which every run of identical consecutive values
        is reduced to its first element, e.g. ``[1, 1, 2, 2, 1] -> [1, 2, 1]``.
        An empty input yields an empty array.
    """
    arr = np.asarray(arr)
    # Without this guard the mask below would have length 1 for a length-0
    # array and the boolean indexing would raise IndexError.
    if arr.size == 0:
        return arr.copy()
    # Keep the first element, then every element that differs from its predecessor.
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]