In [1]:
import torch
import torchaudio
from llmspeech.model import GPT
from llmspeech.text import tokenize, detokenize
from llmspeech.audio import unflatten_and_remove_offsets, CODEC_SAMPLE_RATE
from llmspeech.utils import count_parameters
import IPython.display as ipd
import time
from snac import SNAC
from torch import Tensor
import torch.nn.functional as F
from contextlib import nullcontext
from llmspeech.generation import generate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Inference configuration.
# `precision` is the dtype the model weights are cast to. When the weights stay in
# float32, generation still runs under bfloat16 autocast (see `ctx` below).
compile = True  # forwarded as `generate(..., compile=compile)` in the generation cell
# precision = torch.bfloat16
precision = torch.float32
device = "cuda"

# `torch.cuda.amp.autocast(...)` is deprecated in recent PyTorch releases;
# `torch.autocast` with an explicit device type is the supported equivalent.
# Autocast is only enabled for float32 weights; with bfloat16 weights a no-op
# context is used instead.
ctx = (
    torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16)
    if precision == torch.float32
    else nullcontext()
)

# NOTE(review): 10000.0 is the standard RoPE base; an earlier comment described it
# as "twice the default" — confirm against GPT's own default before relying on either.
rope_theta = 10000.0

# Hub checkpoint alternative:
# name = "gpt-small-055000.pt"
# step = name.split("-")[-1].split(".")[0]
# model = (
#     GPT.from_huggingface(name, rope_theta=rope_theta)
#     .eval()
#     .to(dtype=precision, device=device)
# )

# Local checkpoint to evaluate. Earlier runs, kept for reference:
#   ./runs/850e3bv2/gpt-000200.pt
#   ./runs/ofmpactw/gpt-000500.pt
#   ./runs/bw6215fn/gpt-001000.pt
name = "./runs/f7zjo4js/gpt-001000.pt"  # with style tags

# Training step parsed from the checkpoint filename, e.g. "gpt-001000.pt" -> "001000".
step = name.split("-")[-1].split(".")[0]
model = (
    GPT.from_pretrained(name, rope_theta=rope_theta)
    .eval()
    .to(dtype=precision, device=device)
)
config = model.config

# SNAC codec used to decode generated audio tokens back into a waveform.
codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)


In [None]:

# Sample passages of varying length used to probe the model.
_1984 = "It was a bright cold day in April, and the clocks were striking thirteen. Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind, slipped quickly through the glass doors of Victory Mansions, though not quickly enough to prevent a swirl of gritty dust from entering along with him."
_1984_2 = "The hallway smelt of boiled cabbage and old rag mats. At one end of it a coloured poster, too large for indoor display, had been tacked to the wall. It depicted simply an enormous face, more than a metre wide: the face of a man of about forty-five, with a heavy black moustache and ruggedly handsome features."

harry_potter = "Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense."

texts = [
    "Does this work?",
    _1984,
    _1984_2,
    harry_potter,
]

# Checkpoints trained with style tags take a "[style]" prefix; render every passage
# once per style. `getattr` keeps this cell working with configs that predate the flag.
if getattr(config, "with_style_prompts", False):
    print("Using style prompts")
    texts = [
        f"[{style}]{text}"
        for text in texts
        for style in ["default", "narration", "whisper"]
    ]


waveforms = []

temperature, top_k = 1.0, 64
# Token budget per generation. NOTE(review): presumably ~20 s of audio at 84 audio
# tokens per second — confirm against the codec frame rate.
max_new_tokens = 20 * 84
# Flattened audio tokens per codec frame; outputs are trimmed to whole frames before
# decoding. NOTE(review): assumed to match the layout `unflatten_and_remove_offsets`
# expects (7 = 1 + 2 + 4 codes for SNAC's three codebooks) — confirm.
# Kept distinct from `step` so the checkpoint step from the setup cell is not overwritten.
tokens_per_frame = 7
# One timestamp per run. Together with the prompt index it gives every output file a
# unique name; a per-iteration `int(time.time())` collides when two generations
# finish within the same second.
run_id = int(time.time())

with torch.inference_mode():
    for i, text in enumerate(texts):
        print(f"{len(text)=}")
        t1 = time.perf_counter()
        input_ids = [config.bos_token_id] + tokenize(text) + [config.boa_token_id]
        input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)

        with ctx:
            output_ids = generate(
                model,
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                compile=compile,
            )

        # CUDA launches are asynchronous; wait for completion so the latency is real.
        torch.cuda.synchronize()

        t2 = time.perf_counter()

        # With compile=True the first iteration's latency also includes compilation.
        latency = t2 - t1

        # NOTE(review): assumes `generate` returns only the newly generated tokens
        # (the decoding below treats every output id as an audio token).
        tokens_per_second = output_ids.size(-1) / latency

        print(
            f"{i=} {compile=} {text=} {output_ids.size(-1)} generated tokens {tokens_per_second=:.2f} {latency=:.2f}"
        )

        # Drop a trailing partial frame so the sequence unflattens cleanly.
        T_out = output_ids.size(-1)
        rem = T_out % tokens_per_frame
        if rem > 0:
            print(f"{rem=}")
            output_ids = output_ids[:, :-rem]

        codes = unflatten_and_remove_offsets(
            output_ids[0],
            n_text_tokens=config.n_text_tokens,
            codebook_size=config.codebook_size,
        )
        waveform = codec.decode(codes)
        waveform = waveform.squeeze(0).cpu()
        waveforms.append(waveform)
        torchaudio.save(
            f"{temperature=}-{top_k=}-{run_id}-{i:03d}.wav",
            waveform,
            CODEC_SAMPLE_RATE,
        )
