In [4]:
import os
import time

import IPython.display as ipd
import torch
import torch.nn.functional as F
import torchaudio
from snac import SNAC
from torch import Tensor

from llmspeech.audio import unflatten_and_remove_offsets, CODEC_SAMPLE_RATE
from llmspeech.model import GPT
from llmspeech.text import tokenize, detokenize
from llmspeech.utils import count_parameters

In [5]:
# torch._inductor.config.coordinate_descent_tuning = True
# torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

In [6]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
# load from a local path
# checkpoint_path = "small/gpt-055000.pt"
# checkpoint_path = "medium/gpt-020000.pt"
# step = checkpoint_path.split("-")[-1].split(".")[0]
# model = GPT.from_pretrained(checkpoint_path).eval().to(device)

In [8]:
# Checkpoint file name handed to `GPT.from_huggingface`. It is the single source of
# truth: both the training step and the loaded weights are derived from it.
name = "gpt-small-055000.pt"
# "gpt-small-055000.pt" -> "055000" (text between the last "-" and the first "." after it)
step = name.split("-")[-1].split(".")[0]
model = GPT.from_huggingface(name).eval().to(device)

In [9]:
# Pretrained SNAC codec used to turn generated codes back into a waveform;
# switched to eval mode and moved to the same device as the language model.
codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
codec = codec.eval().to(device)

In [10]:
# Model size in whole millions of parameters (truncated); also used in saved file names.
n_parameters = count_parameters(model)
n_parameters = int(n_parameters / 1e6)
print(f"{type(model).__name__} {n_parameters}M parameters")

GPT 104M parameters


In [11]:
# Keep a handle on the model config; later cells read the special-token ids
# (bos/boa) and the text-vocabulary / codebook sizes from it.
config = model.config

In [12]:
# Prompts to synthesize; each entry is processed independently by the generation cells.
# Uncomment additional entries to try more (and longer) prompts.
texts = [
    "The quick brown fox jumps over the lazy dog.",
    # "Does this work?",
    # "It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity, it was the season of light, it was the season of darkness, it was the spring of hope, it was the winter of despair.",
    # "It was a bright cold day in April, and the clocks were striking thirteen.",
]

In [13]:
# Baseline: synthesize every prompt with the model's own `generate`, play the audio
# inline and save it as a WAV file under examples/.
os.makedirs("examples", exist_ok=True)  # make sure the output folder exists before saving

with torch.inference_mode():
    for text in texts:
        # Prompt layout: [bos] + text tokens + [boa]; a batch dimension of 1 is added.
        input_ids = [config.bos_token_id] + tokenize(text) + [config.boa_token_id]
        input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)

        # `torch.autocast` replaces the deprecated `torch.cuda.amp.autocast`;
        # it needs the bare device type ("cuda"), not a full device string.
        with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
            output_ids = model.generate(
                input_ids, max_new_tokens=20 * 84, temperature=1.0, top_k=64
            )

        # Turn the flat token stream back into codec codes, then decode to audio.
        codes = unflatten_and_remove_offsets(
            output_ids[0],
            n_text_tokens=config.n_text_tokens,
            codebook_size=config.codebook_size,
        )
        waveform = codec.decode(codes)
        print(text)
        waveform = waveform.squeeze(0).cpu()
        ipd.display(ipd.Audio(waveform, rate=CODEC_SAMPLE_RATE))
        # File name encodes checkpoint step, model size and a timestamp so runs don't overwrite each other.
        torchaudio.save(f"examples/gpt-{step}-{n_parameters}M-{int(time.time())}.wav", waveform, CODEC_SAMPLE_RATE)

The quick brown fox jumps over the lazy dog.


In [14]:
# Inspect the size of the token tensor produced for the last prompt in the loop above.
output_ids.shape

torch.Size([1, 406])

In [15]:
def prefill(model: GPT, input_ids: Tensor, input_pos: Tensor, temperature=1.0, top_k=0):
    """Run the model over the whole prompt and pick the first new token.

    Args:
        model: Called as ``model(input_ids=..., input_pos=..., num_last_tokens=1)``;
            must return logits whose last position along dim 1 is used for sampling.
        input_ids: Prompt token ids, forwarded to the model unchanged.
        input_pos: Positions belonging to ``input_ids``, forwarded unchanged.
        temperature: Divisor applied to the logits before sampling. ``0`` selects the
            highest-scoring token deterministically instead of sampling.
        top_k: When greater than 1, sampling is restricted to the ``top_k`` largest
            logits. Values of 0 or 1 leave the distribution unfiltered (note that
            ``top_k=1`` therefore does NOT mean greedy decoding; use ``temperature=0``).

    Returns:
        Integer tensor with one chosen token id per row of logits, shape ``(rows, 1)``.
    """
    logits = model(input_ids=input_ids, input_pos=input_pos, num_last_tokens=1)
    logits = logits[:, -1]

    if temperature == 0:
        # Dividing by zero would turn the logits into inf/nan, which cannot be sampled
        # from; treat a zero temperature as "take the most likely token".
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k > 1:
        top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        kth_largest = top_values[:, [-1]]
        # Out-of-place masking: everything below the k-th largest logit gets zero probability.
        logits = logits.masked_fill(logits < kth_largest, -float("inf"))

    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

In [16]:
def decode(model: GPT, cur_token: Tensor, input_pos: Tensor, temperature=1.0, top_k=1):
    """Feed the most recent token to the model and pick the token that follows it.

    Args:
        model: Called as ``model(input_ids=..., input_pos=..., num_last_tokens=1)``;
            must return logits whose last position along dim 1 is used for sampling.
        cur_token: The previously chosen token, forwarded to the model as ``input_ids``.
        input_pos: Position of ``cur_token``, forwarded unchanged.
        temperature: Divisor applied to the logits before sampling. ``0`` selects the
            highest-scoring token deterministically instead of sampling.
        top_k: When greater than 1, sampling is restricted to the ``top_k`` largest
            logits. Values of 0 or 1 (including the default) leave the distribution
            unfiltered; use ``temperature=0`` for greedy decoding.

    Returns:
        Integer tensor with one chosen token id per row of logits, shape ``(rows, 1)``.
    """
    logits = model(input_ids=cur_token, input_pos=input_pos, num_last_tokens=1)
    logits = logits[:, -1]

    if temperature == 0:
        # Dividing by zero would turn the logits into inf/nan, which cannot be sampled
        # from; treat a zero temperature as "take the most likely token".
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k > 1:
        top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        kth_largest = top_values[:, [-1]]
        # Out-of-place masking: everything below the k-th largest logit gets zero probability.
        logits = logits.masked_fill(logits < kth_largest, -float("inf"))

    probs = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

In [17]:
# Compiled wrappers keyed by the eager function object. Each function is wrapped by
# torch.compile at most once, and re-running a cell that redefines `prefill`/`decode`
# (which creates a new function object) gets a fresh wrapper instead of a stale one.
_COMPILED_FNS: dict = {}


def _compile_once(fn, **compile_kwargs):
    """Return the cached ``torch.compile`` wrapper for ``fn``, creating it on first use."""
    if fn not in _COMPILED_FNS:
        _COMPILED_FNS[fn] = torch.compile(fn, **compile_kwargs)
    return _COMPILED_FNS[fn]


@torch.inference_mode()
def generate(
    model: GPT,
    input_ids: Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    max_new_tokens: int = 700,
    do_masking: bool = False,
):
    """Sample tokens after ``input_ids`` using compiled prefill/decode steps.

    One token comes from ``prefill`` over the whole prompt, followed by
    ``max_new_tokens - 1`` single-token ``decode`` steps. Generation always runs for
    the full budget; the result is truncated afterwards at the first EOS token.

    Args:
        model: Model exposing ``decoder.allocate_inference_cache`` and
            ``config.eos_token_id``; it is passed through to ``prefill``/``decode``.
        input_ids: Prompt token ids; the last dimension is the prompt length.
            The inference cache and the decode loop assume a batch of one sequence.
        temperature: Forwarded to ``prefill``/``decode``.
        top_k: Forwarded to ``prefill``/``decode``.
        max_new_tokens: Number of tokens to sample (at least one is always sampled).
        do_masking: Accepted for interface compatibility; currently unused.

    Returns:
        Tensor of shape ``(1, n)`` holding the sampled tokens up to, but not
        including, the first EOS token (or all sampled tokens if no EOS occurred).
    """
    # Compile lazily and only once; the module-level `prefill`/`decode` stay untouched,
    # so repeated calls no longer stack torch.compile wrappers on top of each other.
    # TODO still getting: skipping cudagraphs due to ['incompatible ops']
    compiled_prefill = _compile_once(prefill, fullgraph=True, dynamic=True)
    compiled_decode = _compile_once(decode, mode="reduce-overhead", fullgraph=True)

    device = input_ids.device
    T = input_ids.size(-1)

    # The cache is allocated for a single sequence.
    batch_size = 1
    model.decoder.allocate_inference_cache(batch_size, device)

    # Prefill: run the whole prompt at positions 0..T-1 and sample the first token.
    input_pos = torch.arange(0, T, device=device)
    next_token = compiled_prefill(
        model, input_ids, input_pos, temperature=temperature, top_k=top_k
    )

    # Clone before storing so the kept token never aliases a tensor returned by a
    # later compiled call.
    new_tokens = [next_token.clone()]
    cur_token = next_token.view(1, -1)
    input_pos = torch.tensor([T], device=device, dtype=torch.int32)

    # The SDP backend selection is the same for every step, so enter it once.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_mem_efficient=False, enable_math=True
    ):  # Actually better for Inductor to codegen attention here
        for _ in range(max_new_tokens - 1):
            next_token = compiled_decode(
                model, cur_token, input_pos, temperature=temperature, top_k=top_k
            )
            new_tokens.append(next_token.clone())
            cur_token = next_token.view(1, -1)
            input_pos += 1

    new_tokens = torch.cat(new_tokens, dim=-1)

    print(f"{new_tokens.size()=}")

    # Rows of (batch index, token index) for every EOS occurrence, in order.
    stop_idx = (new_tokens == model.config.eos_token_id).nonzero()
    print(stop_idx)

    if stop_idx.numel() > 0:
        # Token index of the first EOS; everything from there on is dropped.
        stop_idx = stop_idx[0, -1].item()
        stop_reason = "eos"
    else:
        stop_idx = new_tokens.size(-1)
        stop_reason = "max_tokens"

    print(f"{stop_idx=} {stop_reason=}")
    return new_tokens[:, :stop_idx]

In [18]:
# Same pipeline as the baseline cell, but using the compiled `generate` defined above.
with torch.inference_mode():
    for text in texts:
        # Prompt layout: [bos] + text tokens + [boa]; a batch dimension of 1 is added.
        input_ids = [config.bos_token_id] + tokenize(text) + [config.boa_token_id]
        input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)

        # `torch.autocast` replaces the deprecated `torch.cuda.amp.autocast`;
        # it needs the bare device type ("cuda"), not a full device string.
        with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
            output_ids = generate(
                model, input_ids, max_new_tokens=10 * 84, temperature=1.0, top_k=64
            )

        print(f"{output_ids.shape=}")

        # Turn the flat token stream back into codec codes, then decode to audio.
        codes = unflatten_and_remove_offsets(
            output_ids[0],
            n_text_tokens=config.n_text_tokens,
            codebook_size=config.codebook_size,
        )
        waveform = codec.decode(codes)
        print(text)
        waveform = waveform.squeeze(0).cpu()
        ipd.display(ipd.Audio(waveform, rate=CODEC_SAMPLE_RATE))

skipping cudagraphs due to ['incompatible ops']


new_tokens.size()=torch.Size([1, 840])
tensor([[  0, 364],
        [  0, 553],
        [  0, 686]], device='cuda:0')
stop_idx=364 stop_reason='eos'
output_ids.shape=torch.Size([1, 364])
The quick brown fox jumps over the lazy dog.


In [19]:
# model.decoder.blocks[0].attn.kv_cache.k_cache