In [3]:
import os
import time

import IPython.display as ipd
import torch
import torchaudio
from snac import SNAC

from llmspeech.audio import unflatten_and_remove_offsets, CODEC_SAMPLE_RATE
from llmspeech.model import GPT
from llmspeech.text import tokenize, detokenize
from llmspeech.utils import count_parameters

In [4]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
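# Alternative: load a checkpoint from a local path instead of via
# GPT.from_huggingface (uncomment one checkpoint_path and the two lines below it).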
# checkpoint_path = "small/gpt-055000.pt"
# checkpoint_path = "medium/gpt-020000.pt"
# step = checkpoint_path.split("-")[-1].split(".")[0]
# model = GPT.from_pretrained(checkpoint_path).eval().to(device)

In [6]:
# Checkpoint file name handed to GPT.from_huggingface. `step` is parsed from
# it and reused later when naming the saved audio files.
name = "gpt-small-055000.pt"
step = name.split("-")[-1].split(".")[0]  # "gpt-small-055000.pt" -> "055000"
# Pass `name` rather than repeating the literal so the loaded weights and
# `step` cannot drift apart when the checkpoint is changed.
model = GPT.from_huggingface(name).eval().to(device)

In [7]:
# Pretrained SNAC codec, switched to eval mode and moved to `device`.
# The generation cell uses it to decode generated codes into a waveform.
codec = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
codec = codec.eval().to(device)

In [8]:
# Model size in millions of parameters, truncated by int() rather than
# rounded. Also reused when naming the saved audio files.
n_parameters = int(count_parameters(model) / 1e6)
model_name = type(model).__name__
print(f"{model_name} {n_parameters}M parameters")

GPT 104M parameters


In [9]:
# Short alias for the model configuration; the generation cell reads
# bos_token_id, boa_token_id, n_text_tokens and codebook_size from it.
config = model.config

In [10]:
# Prompts to synthesize, one audio clip per entry. They range from a short
# question to a very long sentence.
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Does this work?",
    "It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity, it was the season of light, it was the season of darkness, it was the spring of hope, it was the winter of despair.",
    "It was a bright cold day in April, and the clocks were striking thirteen.",
]

In [11]:
# Synthesize speech for every prompt in `texts`: sample tokens from the model,
# decode them with the codec, then play each clip inline and save it as a WAV.

# Upper bound on the number of tokens sampled per prompt.
# NOTE(review): presumably 20 seconds x 84 codec tokens per second -- confirm.
max_new_tokens = 20 * 84

# Make sure the output directory exists before anything is saved into it.
os.makedirs("examples", exist_ok=True)

with torch.inference_mode():
    for index, text in enumerate(texts):
        # Prompt layout: BOS token, the tokenized text, then the BOA token
        # (presumably "beginning of audio" -- confirm against the model code).
        prompt = [config.bos_token_id] + tokenize(text) + [config.boa_token_id]
        input_ids = torch.tensor(prompt, device=device).unsqueeze(0)  # add a batch dimension

        # torch.autocast replaces the deprecated torch.cuda.amp.autocast; the
        # device type is derived from `device` so "cuda:0"-style values work too.
        with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16):
            output_ids = model.generate(
                input_ids, max_new_tokens=max_new_tokens, temperature=1.0, top_k=64
            )

        # Only the first (and only) sequence in the batch is decoded.
        codes = unflatten_and_remove_offsets(
            output_ids[0],
            n_text_tokens=config.n_text_tokens,
            codebook_size=config.codebook_size,
        )
        waveform = codec.decode(codes)
        print(text)
        waveform = waveform.squeeze(0).cpu()
        ipd.display(ipd.Audio(waveform, rate=CODEC_SAMPLE_RATE))

        # The timestamp only has one-second resolution, so the prompt index is
        # appended to keep two quickly generated clips from overwriting each other.
        output_path = f"examples/gpt-{step}-{n_parameters}M-{int(time.time())}-{index:02d}.wav"
        torchaudio.save(output_path, waveform, CODEC_SAMPLE_RATE)

The quick brown fox jumps over the lazy dog.


Does this work?


It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity, it was the season of light, it was the season of darkness, it was the spring of hope, it was the winter of despair.


It was a bright cold day in April, and the clocks were striking thirteen.
