In [15]:
import sentencepiece as spm
import torch
from ulm.module import LitGPT, GPTConfig
import json
import torch
from torchaudio.datasets import LIBRISPEECH
from IPython.display import Audio

In [16]:
# Token budget for prompt + generated continuation (the generate_* helpers ask
# for `block_size - <prompt length>` new tokens). Presumably equals the GPT's
# context length in config.json — TODO confirm.
block_size: int = 512

In [17]:
checkpoint_path = "/home/nicolvisser/workspace/ulm/lightning_logs/n_units-500-dp_lambda-0/checkpoints/epoch=26-step=8424.ckpt"
n_units = 500
dp_lambda = 0

In [18]:
# Checkpoint trained with n_units=500, dp_lambda=20 (as encoded in its path).
# Running this cell replaces any values set by the previous configuration cell;
# n_units and dp_lambda also select the hub model, tokenizer and config below.
dp_lambda = 20
n_units = 500
checkpoint_path = "/home/nicolvisser/workspace/ulm/lightning_logs/n_units-500-dp_lambda-20/checkpoints/epoch=320-step=100000.ckpt"

In [19]:
# Speech pipeline from torch.hub matching the unit configuration: later cells
# use gslm.encode(wav, sr) -> units and gslm.decode(units) -> (wav, sr).
# trust_repo=True executes the hub repository's code without prompting.
gslm = torch.hub.load(
    "nicolvisser/gslm-hubert-hifigan:master",
    "gslm",
    n_units=n_units,
    dp_lambda=dp_lambda,
    trust_repo=True,
).to("cuda")

Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_gslm-hubert-hifigan_master
Using cache found in /home/nicolvisser/.cache/torch/hub/bshall_hubert_main
Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_hubert-kmeans_main
Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_duration-predictor_main
Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_acoustic-model_main
Using cache found in /home/nicolvisser/.cache/torch/hub/bshall_hifigan_main


In [20]:
# Character-level SentencePiece model stored in the data directory for this
# (n_units, dp_lambda) setting; it tokenizes the unit-as-character strings.
sp_model_file = f"/home/nicolvisser/workspace/ulm/data/n_units-{n_units}-dp_lambda-{dp_lambda}/character_level.model"
sp = spm.SentencePieceProcessor(model_file=sp_model_file)

In [21]:
# GPT hyperparameters saved next to the training data for this configuration.
config_path = f"/home/nicolvisser/workspace/ulm/data/n_units-{n_units}-dp_lambda-{dp_lambda}/config.json"
with open(config_path, "r") as f:
    config = GPTConfig(**json.load(f))

# load_from_checkpoint leaves the module in training mode, so switch to eval
# mode for inference. This disables dropout if config enables any.
# .eval() returns the module itself, so chaining keeps the cell free of a
# displayed repr.
model = LitGPT.load_from_checkpoint(
    checkpoint_path,
    config=config,
).eval()

number of parameters: 38.09M


In [31]:
def generate_unconditional(wav, sr, temperature=1.0, top_k=None):
    """Sample speech from the unit LM starting from one random unit and vocode it.

    ``wav`` and ``sr`` are not used; they only keep the signature
    interchangeable with ``generate_continuation``.

    Args:
        wav: Ignored.
        sr: Ignored.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Forwarded unchanged to ``model.generate``.

    Returns:
        IPython ``Audio`` of the generated units, decoded by ``gslm.decode`` at
        the sample rate it reports.

    Raises:
        ValueError: If the generated text contains no character that maps to a
            valid unit.
    """
    # Unit u is written as the character chr(0x4E00 + u), one character per
    # unit, so unit sequences can go through the character-level tokenizer.
    char_offset = 0x4E00

    # Prompt: a single unit drawn uniformly from [0, n_units).
    seed_units = torch.randint(0, n_units, (1,)).tolist()
    prompt_text = "".join(chr(u + char_offset) for u in seed_units)
    prompt_ids = torch.tensor(sp.EncodeAsIds(prompt_text), dtype=torch.long).cuda()

    # Budget new tokens against the prompt length in *tokens*, not units:
    # SentencePiece may emit extra pieces (by default it prepends a "▁" dummy
    # prefix), so counting units would overshoot block_size. Clamp at zero.
    max_new_tokens = max(0, block_size - prompt_ids.numel())
    continued_ids = model.generate(
        prompt_ids.unsqueeze(0),
        max_new_tokens,
        temperature=temperature,
        top_k=top_k,
    )

    # Map characters back to units. Drop anything outside the unit range
    # (e.g. "▁" or spaces, whose code points lie below 0x4E00).
    continued_text = sp.DecodeIds(continued_ids.squeeze(0).tolist())
    codepoint_units = (ord(ch) - char_offset for ch in continued_text)
    continued_units = [u for u in codepoint_units if 0 <= u < n_units]
    if not continued_units:
        raise ValueError("generated text contains no valid unit characters")

    units_tensor = torch.tensor(continued_units, dtype=torch.long).cuda()
    wav_continued, sr_continued = gslm.decode(units_tensor)
    return Audio(wav_continued.cpu(), rate=sr_continued)

In [28]:
def generate_continuation(wav, sr, temperature=1.0, top_k=None):
    """Encode a speech prompt to units, continue it with the unit LM, and vocode it.

    Args:
        wav: Prompt waveform tensor; moved to CUDA before ``gslm.encode``.
            Presumably (channels, samples) as returned by LIBRISPEECH — TODO
            confirm what ``gslm.encode`` expects.
        sr: Sample rate of ``wav``, forwarded to ``gslm.encode``.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Forwarded unchanged to ``model.generate``.

    Returns:
        IPython ``Audio`` of the prompt followed by its continuation, at the
        sample rate reported by ``gslm.decode``.

    Raises:
        ValueError: If no character of the generated text maps to a valid unit.
    """
    # Unit u is written as the character chr(0x4E00 + u), one character per
    # unit, so unit sequences can go through the character-level tokenizer.
    char_offset = 0x4E00

    # One host transfer instead of converting each unit tensor element
    # separately. Assumes encode returns a 1-D sequence of unit ids — TODO confirm.
    prompt_units = torch.as_tensor(gslm.encode(wav.to("cuda"), sr)).tolist()
    prompt_text = "".join(chr(u + char_offset) for u in prompt_units)
    prompt_ids = torch.tensor(sp.EncodeAsIds(prompt_text), dtype=torch.long).cuda()

    # Budget new tokens against the prompt length in *tokens*, not units:
    # SentencePiece may emit extra pieces (by default it prepends a "▁" dummy
    # prefix). A prompt already at or over block_size gets no new tokens.
    max_new_tokens = max(0, block_size - prompt_ids.numel())
    continued_ids = model.generate(
        prompt_ids.unsqueeze(0),
        max_new_tokens,
        temperature=temperature,
        top_k=top_k,
    )

    # Map characters back to units. Drop anything outside the unit range
    # (e.g. "▁" or spaces, whose code points lie below 0x4E00).
    continued_text = sp.DecodeIds(continued_ids.squeeze(0).tolist())
    codepoint_units = (ord(ch) - char_offset for ch in continued_text)
    continued_units = [u for u in codepoint_units if 0 <= u < n_units]
    if not continued_units:
        raise ValueError("generated text contains no valid unit characters")

    units_tensor = torch.tensor(continued_units, dtype=torch.long).cuda()
    wav_continued, sr_continued = gslm.decode(units_tensor)
    return Audio(wav_continued.cpu(), rate=sr_continued)

In [24]:
def generate_continuation_with_beam_search(wav, sr, beam_size=None, temperature=1.0):
    """Like ``generate_continuation`` but decodes with ``model.generate_beam_search``.

    Args:
        wav: Prompt waveform tensor; moved to CUDA before ``gslm.encode``.
            Presumably (channels, samples) — TODO confirm what ``gslm.encode``
            expects.
        sr: Sample rate of ``wav``, forwarded to ``gslm.encode``.
        beam_size: Forwarded unchanged to ``model.generate_beam_search``. The
            meaning of ``None`` is defined there — TODO confirm.
        temperature: Forwarded to ``model.generate_beam_search``.

    Returns:
        IPython ``Audio`` of the prompt followed by its continuation, at the
        sample rate reported by ``gslm.decode``.

    Raises:
        ValueError: If no character of the generated text maps to a valid unit.
    """
    # Unit u is written as the character chr(0x4E00 + u), one character per
    # unit, so unit sequences can go through the character-level tokenizer.
    char_offset = 0x4E00

    # One host transfer instead of converting each unit tensor element
    # separately. Assumes encode returns a 1-D sequence of unit ids — TODO confirm.
    prompt_units = torch.as_tensor(gslm.encode(wav.to("cuda"), sr)).tolist()
    prompt_text = "".join(chr(u + char_offset) for u in prompt_units)
    prompt_ids = torch.tensor(sp.EncodeAsIds(prompt_text), dtype=torch.long).cuda()

    # Budget new tokens against the prompt length in *tokens*, not units:
    # SentencePiece may emit extra pieces (by default it prepends a "▁" dummy
    # prefix). A prompt already at or over block_size gets no new tokens.
    max_new_tokens = max(0, block_size - prompt_ids.numel())
    continued_ids = model.generate_beam_search(
        prompt_ids.unsqueeze(0),
        max_new_tokens,
        beam_size=beam_size,
        temperature=temperature,
    )

    # Map characters back to units. Drop anything outside the unit range
    # (e.g. "▁" or spaces, whose code points lie below 0x4E00).
    continued_text = sp.DecodeIds(continued_ids.squeeze(0).tolist())
    codepoint_units = (ord(ch) - char_offset for ch in continued_text)
    continued_units = [u for u in codepoint_units if 0 <= u < n_units]
    if not continued_units:
        raise ValueError("generated text contains no valid unit characters")

    units_tensor = torch.tensor(continued_units, dtype=torch.long).cuda()
    wav_continued, sr_continued = gslm.decode(units_tensor)
    return Audio(wav_continued.cpu(), rate=sr_continued)

In [25]:
# Prompt audio: one utterance from LibriSpeech dev-clean (must already be on
# disk; download=False).
utterance_index = 400
# Optionally keep only the first `prompt_seconds` of audio as the prompt
# (None = whole utterance). NOTE(review): if the prompt's tokens already fill
# block_size, the generate_* helpers have no room left for a continuation.
prompt_seconds = None

dataset = LIBRISPEECH("/home/nicolvisser/datasets/", url="dev-clean", download=False)
wav, sr, *_ = dataset[utterance_index]
if prompt_seconds is not None:
    # Trim along the last (sample) axis. wav is indexed as 2-D, matching the
    # original `wav[:, :]`.
    wav = wav[:, : int(prompt_seconds * sr)]
Audio(wav, rate=sr)

In [35]:
# Unconditional sample: generate_unconditional ignores the prompt audio and
# starts from a random unit. Low temperature, top_k=50.
generate_unconditional(wav=wav, sr=sr, temperature=0.03, top_k=50)

In [27]:
# Continue the loaded LibriSpeech prompt with the same low-temperature,
# top_k=50 sampling settings.
generate_continuation(wav=wav, sr=sr, temperature=0.03, top_k=50)