In [1]:
import sys
sys.path.append('..')

import os
import math
import torch
import random
import numpy as np
from pathlib import Path
from torch.cuda import empty_cache
from IPython.display import Audio
import matplotlib.pyplot as plt

from configs.commons import Config as cfg
from configs.commons import DEVICE, CACHE_DIR, CTX
from configs.constants import *

from omni.hfload import convert_to_hf
from datalib.tokenlib import get_tokenizer

Cache directory at:  /home/.cache/indri
Gap tokens:  7


In [2]:
DEVICE = 'cuda:0'

In [73]:
omni_model = convert_to_hf(path=f'/home/.cache/indri/models/omni_tasks_large/gpt_6300.pt', device=DEVICE)
semantic_acoustic_model = convert_to_hf(path=f'/home/.cache/indri/romit/models/semantic_acoustic_tasks_small/gpt_6500.pt', device=DEVICE)

text_tokenizer = get_tokenizer(TEXT, device='cpu')
acoustic_tokenizer = get_tokenizer(ACOUSTIC, device=DEVICE)

loaded config GPTConfig(block_size=1024, vocab_size=53312, n_layer=36, n_head=20, n_embd=1280, dropout=0.0, bias=True)
loaded config GPTConfig(block_size=3072, vocab_size=53312, n_layer=12, n_head=12, n_embd=768, dropout=0.0, bias=False)
text vocab size 50257


In [74]:
convert_token = text_tokenizer.encode(cfg.TASK_TOKENS[CONVERT])
continue_token = text_tokenizer.encode(cfg.TASK_TOKENS[CONTINUE])
stop_token = text_tokenizer.encode('<|endoftext|>')
semantic_modality_token = text_tokenizer.encode(cfg.MODALITY_TOKENS[SEMANTIC])
acoustic_modality_token = text_tokenizer.encode(cfg.MODALITY_TOKENS[ACOUSTIC])
text_modality_token = text_tokenizer.encode(cfg.MODALITY_TOKENS[TEXT])
speaker_id = text_tokenizer.encode("[spkr_unk]")

omni_model.generation_config.eos_token_id = stop_token
semantic_acoustic_model.generation_config.eos_token_id = stop_token

In [121]:
random_txt = "once upon a time there was a girl named emily"
txt_toks = np.array(text_tokenizer.encode(random_txt))

In [139]:
input_tokens = np.hstack([
    text_modality_token,
    txt_toks,
    convert_token,
    speaker_id,
    semantic_modality_token
])
input_tokens = (torch.tensor(input_tokens, dtype=torch.long, device=DEVICE)[None, ...])
print(f'Text tokens: {input_tokens.shape}')

Text tokens: torch.Size([1, 29])


In [140]:
with CTX:
    semantic_tokens = omni_model.generate(
        input_tokens,
        max_length=1024,
        temperature=0.7,
        top_k=100,
        do_sample=True
    )
    semantic_tokens = semantic_tokens.detach().cpu().numpy()[0]
    semantic_tokens = semantic_tokens[input_tokens.shape[-1]:]
    print(semantic_tokens.shape)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


(118,)


In [141]:
semantic_tokens

array([50754, 50320, 51248, 50419, 50792, 50358, 50998, 50907, 51073,
       50582, 51058, 50423, 50806, 51079, 50346, 50451, 50921, 50289,
       50688, 50788, 50791, 51160, 50964, 50289, 50999, 50355, 50776,
       50283, 50461, 50537, 50925, 50424, 50912, 51021, 50771, 50553,
       50971, 50866, 50680, 50822, 51079, 50346, 50451, 50921, 50614,
       50508, 50669, 50517, 51230, 50545, 51053, 50904, 50760, 50377,
       50453, 50423, 50313, 50742, 50578, 51205, 50343, 50547, 51235,
       51090, 50641, 50506, 51226, 50480, 50704, 50983, 50301, 51159,
       51009, 51076, 50915, 50754, 50320, 50922, 51248, 51138, 50588,
       50430, 51202, 51044, 51192, 50358, 50998, 50898, 50381, 50594,
       50500, 50794, 51116, 50641, 51136, 50947, 50578, 51205, 50890,
       50343, 50576, 50794, 50673, 51093, 50934, 50510, 51048, 50634,
       50812, 50465, 50674, 51012, 50301, 51159, 51076, 51029, 50296,
       50256])

In [142]:
text_tokenizer.decode(semantic_tokens)

'<|endoftext|>'

In [134]:
np.where(semantic_tokens == stop_token)

(array([98]),)

In [135]:
# end_idx = np.where(semantic_tokens == stop_token)[0][0]
end_idx = 98
semantic_tokens = semantic_tokens[0:end_idx]

In [136]:
semantic_tokens = np.hstack([
    semantic_modality_token,
    semantic_tokens,
    convert_token,
    speaker_id,
    acoustic_modality_token
])
semantic_tokens = (torch.tensor(semantic_tokens, dtype=torch.long, device=DEVICE)[None, ...])
print(f'Semantic tokens: {semantic_tokens.shape}')

Semantic tokens: torch.Size([1, 117])


In [137]:
with CTX:
    acoustic_tokens = semantic_acoustic_model.generate(
        semantic_tokens,
        max_length=3072,
        temperature=0.9,
        top_k=100,
        do_sample=True
    )

    acoustic_tokens = acoustic_tokens.detach().cpu().numpy()[0]
    acoustic_tokens = acoustic_tokens[semantic_tokens.shape[-1]:]
    print(acoustic_tokens.shape)

end_idx = np.where(acoustic_tokens == stop_token)[0][0]
acoustic_tokens = acoustic_tokens[0:end_idx]
acoustic_tokens = acoustic_tokens - cfg.OFFSET[ACOUSTIC]

if len(acoustic_tokens) % 2 == 1:
    acoustic_tokens = acoustic_tokens[:-1]

print(f'Acoustic tokens: {acoustic_tokens.shape}')

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


(373,)
Acoustic tokens: (372,)


In [138]:
wav = acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
wav = wav[0].cpu().numpy()
Audio(wav, rate=24000)

decoding tokens shape: (2, 186)


100%|██████████| 1/1 [00:00<00:00,  2.56it/s]


In [96]:
speaker_id = text_tokenizer.encode("[spkr_jenny]")

In [61]:
## Testing with custom tokens
prompt = np.load('../prompts/lj_female_long/tokens.npz')
temp_sem_toks = prompt['SEMANTIC']
temp_sem_toks += cfg.OFFSET[SEMANTIC]
temp_sem_toks = np.hstack([
    semantic_modality_token,
    temp_sem_toks,
    convert_token,
    speaker_id,
    acoustic_modality_token
])
temp_sem_toks = (torch.tensor(temp_sem_toks, dtype=torch.long, device=DEVICE)[None, ...])
print(temp_sem_toks.shape)

# speaker_id = text_tokenizer.encode("[spkr_jenny]")

torch.Size([1, 284])


In [None]:
def replace_consecutive(arr):
    mask = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[mask]