In [None]:
import sys
# Add the parent of the current working directory to the module search path.
# Presumably this is what lets the project-local packages imported below
# (`common`, `omni`, `datalib`) resolve when the notebook is run from its own
# folder — confirm against the repository layout.
sys.path.append('..')

In [None]:
import os
import math
import torch
import random
import numpy as np
from pathlib import Path
from torch.cuda import empty_cache
from IPython.display import Audio
import matplotlib.pyplot as plt

from common import Config as cfg
from common import ACOUSTIC, SEMANTIC, TEXT, DEVICE, ctx, cache_dir

from omni.hfload import convert_to_hf
from datalib.tokenlib import get_tokenizer

In [None]:
def decorate(tokens, type):
    """Wrap a token sequence of one modality in its boundary markers.

    Args:
        tokens: Raw token ids; must support ``+`` with ``cfg.OFFSET[type]``
            (this notebook passes a NumPy array).
        type: Modality key used to index ``cfg.OFFSET``, ``cfg.INFER_TOKEN``
            and ``cfg.STOP_TOKEN``.

    Returns:
        A NumPy array laid out as
        ``[INFER_TOKEN[type], *offset_tokens, STOP_TOKEN[type]]``.
    """
    # Shift the raw ids by the modality's offset before adding the markers.
    shifted = tokens + cfg.OFFSET[type]
    pieces = [cfg.INFER_TOKEN[type], shifted, cfg.STOP_TOKEN[type]]
    return np.hstack(pieces)

def extract_new_tokens(y, target):
    """Return the tokens generated for ``target`` from a full output sequence.

    The result is the slice of ``y`` that lies strictly between the first
    ``cfg.INFER_TOKEN[target]`` and the first ``cfg.STOP_TOKEN[target]`` that
    follows it. If no stop token follows the infer token, everything after the
    infer token is returned.

    Args:
        y: 1-D NumPy array of token ids (prompt plus generated tokens).
        target: Modality key used to index ``cfg.INFER_TOKEN`` and
            ``cfg.STOP_TOKEN``.

    Returns:
        A slice of ``y`` with both marker tokens excluded.

    Raises:
        IndexError: If ``y`` contains no infer token for ``target``.
    """
    start_idx = np.where(y == cfg.INFER_TOKEN[target])[0]
    if start_idx.size == 0:
        # Same exception type as the bare `start_idx[0]` lookup would raise,
        # but with a message that names the actual problem.
        raise IndexError(f'no infer token for {target!r} found in sequence')
    first_new = start_idx[0] + 1

    end_idx = np.where(y == cfg.STOP_TOKEN[target])[0]
    # Only a stop token located after the infer token can terminate the
    # generated span; earlier occurrences are ignored.
    end_idx = end_idx[end_idx >= first_new]

    # Test for presence with `.size`: `.any()` checks the truthiness of the
    # index *values*, so a single match at position 0 would look like no match.
    if end_idx.size:
        return y[first_new:end_idx[0]]
    return y[first_new:]

In [None]:
# Override the DEVICE imported from `common`: use the second GPU when the host
# has one, otherwise fall back to the first GPU, and finally to the CPU, so the
# notebook does not fail on machines with fewer than two CUDA devices.
# NOTE(review): `ctx` from `common` may assume a CUDA device — confirm before
# relying on the CPU fallback.
if torch.cuda.device_count() > 1:
    DEVICE = 'cuda:1'
elif torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

In [None]:
local_dir = f'{cache_dir}/models/omni_774m_tinystories'

# Model that is prompted with text tokens and generates semantic tokens;
# generation stops at the semantic stop token.
omni_model = convert_to_hf(device=DEVICE, path=f'{local_dir}/omni.pt')
omni_model.generation_config.eos_token_id = cfg.STOP_TOKEN[SEMANTIC]

# Model that is prompted with semantic tokens and generates acoustic tokens.
# NOTE(review): this checkpoint path is absolute and user-specific, unlike the
# `cache_dir`-relative path above — confirm it exists on the target machine.
semantic_acoustic_model = convert_to_hf(
    device=DEVICE,
    path=f'/home/.cache/indri/romit/models/gpt_small.pt',
)

# The acoustic tokenizer is created on DEVICE; the text tokenizer on the CPU.
acoustic_tokenizer = get_tokenizer(ACOUSTIC, device=DEVICE)
text_tokenizer = get_tokenizer(TEXT, device='cpu')

In [None]:
# Prompt text to synthesise. It is lower-case with punctuation spelled out as
# `<comma>` / `<period>` markers — presumably the normalised form the text
# tokenizer expects; confirm against the tokenizer's training data.
random_txt = "once upon a time <comma> in a cozy little house at the edge of a forest <comma> lived a curious mouse named pip <period>"

In [None]:
# Encode the prompt to token ids, then apply the TEXT offset and wrap the
# result in the TEXT infer/stop markers.
txt_toks = decorate(
    np.array(text_tokenizer.encode(random_txt)),
    type=TEXT,
)

In [None]:
# Prompt for the first model: the decorated text tokens followed by the
# semantic infer marker.
semantic_marker = [cfg.INFER_TOKEN[SEMANTIC]]
alltokens = np.hstack([txt_toks, semantic_marker])

# Add a leading batch dimension of size 1 for `generate`.
input_tokens = torch.tensor(
    alltokens,
    dtype=torch.long,
    device=DEVICE,
)[None, ...]

In [None]:
with ctx:
    # Sample a continuation of the prompt (sampling is stochastic and no seed
    # is set in this notebook, so outputs differ between runs).
    generated = omni_model.generate(
        input_tokens,
        do_sample=True,
        temperature=0.7,
        top_k=100,
        max_length=1024,
    )
    # Move to host memory, take the single batch element, and strip the prompt
    # so only the newly generated tokens remain.
    sem_toks = generated.detach().cpu().numpy()[0][len(alltokens):]

In [None]:
# Cut the generated semantic tokens at the first stop token. When generation
# reached `max_length` without emitting one, keep the whole sequence instead of
# failing with an IndexError on an empty match array.
stop_positions = np.where(sem_toks == cfg.STOP_TOKEN[SEMANTIC])[0]
end_idx = stop_positions[0] if stop_positions.size else len(sem_toks)
sem_toks = sem_toks[:end_idx]

# Cap the number of semantic tokens passed to the next stage.
sem_toks = sem_toks[:768]

In [None]:
# Prompt for the second model: the semantic tokens followed by the acoustic
# infer marker, with a leading batch dimension of size 1.
# Note: this rebinds `sem_toks` from a NumPy array to a tensor, so the cell is
# not meant to be re-run on its own without re-running the cells above it.
sem_toks = torch.tensor(
    np.hstack([sem_toks, cfg.INFER_TOKEN[ACOUSTIC]]),
    dtype=torch.long,
    device=DEVICE,
)[None, ...]

In [None]:
with ctx:
    # Sample acoustic tokens conditioned on the semantic prompt, then move the
    # single batch element to host memory as a NumPy array. The result still
    # contains the prompt; it is stripped in a later cell.
    acoustic_tokens = (
        semantic_acoustic_model.generate(
            sem_toks,
            do_sample=True,
            temperature=0.95,
            top_k=100,
            # NOTE(review): the length limit is taken from the SEMANTIC block
            # size although this model emits acoustic tokens — confirm this is
            # the intended limit.
            max_length=cfg.BLOCK_SIZE[SEMANTIC],
        )
        .detach()
        .cpu()
        .numpy()[0]
    )

In [None]:
# Keep only the tokens generated after the acoustic infer marker and remove the
# acoustic offset to recover raw ids.
acoustic_tokens = (
    extract_new_tokens(acoustic_tokens, target=ACOUSTIC) - cfg.OFFSET[ACOUSTIC]
)

# Drop a trailing token when the count is odd so the length is even
# (presumably the decoder consumes tokens in pairs — confirm).
even_length = len(acoustic_tokens) - len(acoustic_tokens) % 2
acoustic_tokens = acoustic_tokens[:even_length]

In [None]:
# Sanity check: display the shape of the semantic prompt tensor (batch
# dimension included) next to the shape of the extracted acoustic token array.
sem_toks.shape, acoustic_tokens.shape

In [None]:
# Decode the acoustic tokens, then take the first element of the result and
# convert it to a host-side NumPy array for playback.
decoded = acoustic_tokenizer.decode(torch.tensor(acoustic_tokens))
wav = decoded[0].cpu().numpy()

In [None]:
# Render an inline audio player for the decoded waveform.
# NOTE(review): 24000 is presumably the acoustic tokenizer's output sample
# rate in Hz — confirm, since a mismatch changes playback speed and pitch.
Audio(wav, rate=24000)