In [80]:
import sentencepiece as spm
import torch
from ulm.module import LitGPT, GPTConfig
import json
import torch
from torchaudio.datasets import LIBRISPEECH
from IPython.display import Audio

In [81]:
# Target total sequence length in tokens (prompt + generated continuation);
# presumably matches the context length the LM was trained with — confirm
# against config.json.
block_size: int = 512

In [82]:
# GSLM configuration. Both values are forwarded to the hub entry point and also
# select the tokenizer / LM data and checkpoint directories used further down
# (paths of the form n_units-{n_units}-dp_lambda-{dp_lambda}).
n_units = 500
dp_lambda = 0

# Pretrained speech <-> unit pipeline from torch.hub. Later cells use it as
# gslm.encode(wav, sr) -> units and gslm.decode(units) -> (wav, sr).
# trust_repo=True adds the repo to the trusted list without an interactive prompt.
gslm = torch.hub.load(
    repo_or_dir="nicolvisser/gslm-hubert-hifigan:master",
    model="gslm",
    trust_repo=True,
    n_units=n_units,
    dp_lambda=dp_lambda,
)

Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_gslm-hubert-hifigan_master
Using cache found in /home/nicolvisser/.cache/torch/hub/bshall_hubert_main
Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_hubert-kmeans_main
Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_duration-predictor_main
Using cache found in /home/nicolvisser/.cache/torch/hub/nicolvisser_acoustic-model_main
Using cache found in /home/nicolvisser/.cache/torch/hub/bshall_hifigan_main


In [83]:
# Prompt audio: a single LibriSpeech dev-clean utterance read from a local copy
# (download=False, so the corpus must already exist under this root).
dataset = LIBRISPEECH(
    "/home/nicolvisser/datasets/",
    url="dev-clean",
    download=False,
)

# Keep the waveform and its sample rate; the remaining metadata fields are
# collected into `_` and ignored.
wav, sr, *_ = dataset[300]

# Trim the prompt to its first 34000 samples along the last (time) axis.
wav = wav[..., :34000]

Audio(wav, rate=sr)

In [84]:
# Encode the prompt audio into discrete units. Judging by the printed output,
# this is a 1-D integer tensor with values in [0, n_units).
units = gslm.encode(wav, sr)
units

tensor([ 42, 499, 213, 391, 213, 391, 213, 391, 213, 497, 426, 446,  17, 305,
        154, 497, 394, 280, 472, 212, 396, 205, 395,  72,  33, 476,   5, 387,
        235, 149, 422, 263, 120, 440, 364, 100, 468, 278,  65, 336, 380,  57,
         29, 414, 479, 399, 358,  56,  84,  98,  25,  32,  16, 424, 318, 192,
        352, 169, 113, 227, 320,  49, 352, 299, 393, 322, 207, 261,  63, 251])

In [85]:
# Map every unit id to one CJK ideograph (code point 0x4E00 + id). The unit
# sequence then becomes an ordinary string that the SentencePiece tokenizer
# can encode.
unicode_chars = [chr(0x4E00 + unit_id) for unit_id in units.tolist()]
unicode_text = "".join(unicode_chars)
unicode_text

'个俳仕侇仕侇仕侇仕俱侪侾丑伱亚俱侊优俘仔侌仍例么両俜丅侃仫井侦伇乸侸佬乤俔伖乁佐佼丹丝侞俟侏佦丸乔乢丙丠丐侨伾什你亩乱代佀丱你伫侉佂仏伅丿任'

In [86]:
# Character-level SentencePiece tokenizer for this (n_units, dp_lambda) setup
# (presumably trained on the same unit-character text — confirm).
sp = spm.SentencePieceProcessor(
    model_file=f"/home/nicolvisser/workspace/ulm/data/n_units-{n_units}-dp_lambda-{dp_lambda}/character_level.model"
)

# Tokenize the prompt and place it on the GPU as a 1-D int64 tensor.
# In the run shown, the result is one element longer than `units` (71 vs 70),
# presumably a leading word-boundary piece, so token count != unit count.
ids = torch.tensor(sp.EncodeAsIds(unicode_text), dtype=torch.long, device="cuda")
ids

tensor([123, 120, 114, 336, 287, 336, 287, 336, 287, 336, 369, 376, 375,  26,
         36, 339, 369, 172, 298, 111, 189,  21,  72, 288,   8,  12,  84, 121,
        274, 258, 325, 356, 206, 269, 225, 226, 374, 211,   3, 249, 153, 175,
         43,  29, 109, 416,  85, 196,  98, 348, 197, 129, 108, 385, 110, 228,
         70,  23,  96, 450, 103,  37,   4,  23,  31,  24, 162, 187, 279, 256,
        349], device='cuda:0')

In [87]:
# Largest token id in the prompt. It presumably must stay below the LM's
# vocabulary size from config.json; confirm.
torch.max(ids)

tensor(450, device='cuda:0')

In [88]:
# Model hyper-parameters stored next to the tokenizer for this configuration.
with open(
    f"/home/nicolvisser/workspace/ulm/data/n_units-{n_units}-dp_lambda-{dp_lambda}/config.json",
    "r",
    encoding="utf-8",
) as f:
    config = GPTConfig(**json.load(f))

# Restore the trained LM for inference.
# - config=config overrides the hyper-parameters saved in the checkpoint
#   (keyword arguments to load_from_checkpoint take precedence).
# - map_location="cuda" puts the weights on the GPU where the prompt `ids`
#   live, whatever device the checkpoint was saved from.
# - .eval() switches off training-mode behaviour (e.g. dropout) before sampling.
#   It is a no-op for layers that act the same in both modes. It returns the
#   module itself, and chaining avoids printing its repr as the cell output.
model = LitGPT.load_from_checkpoint(
    f"/home/nicolvisser/workspace/ulm/lightning_logs/n_units-{n_units}-dp_lambda-{dp_lambda}/checkpoints/epoch=26-step=8424.ckpt",
    config=config,
    map_location="cuda",
).eval()

number of parameters: 38.09M


In [89]:
# Sample a continuation that fills the rest of the context window.
# The budget is counted in *tokens*, because `ids` is what the model consumes.
# Its length differs from len(units): SentencePiece added a token to the prompt
# in the run shown. The old budget, block_size - len(units), therefore
# overshot block_size. The second positional argument of generate is presumably
# the number of new tokens to sample; confirm against LitGPT.generate.
# Sampling at temperature=1.0 is unseeded, so re-runs give different continuations.
assert ids.numel() < block_size, "prompt already fills the context window"
ids_continued = model.generate(ids.unsqueeze(0), block_size - ids.numel(), temperature=1.0)
ids_continued

tensor([[123, 120, 114, 336, 287, 336, 287, 336, 287, 336, 369, 376, 375,  26,
          36, 339, 369, 172, 298, 111, 189,  21,  72, 288,   8,  12,  84, 121,
         274, 258, 325, 356, 206, 269, 225, 226, 374, 211,   3, 249, 153, 175,
          43,  29, 109, 416,  85, 196,  98, 348, 197, 129, 108, 385, 110, 228,
          70,  23,  96, 450, 103,  37,   4,  23,  31,  24, 162, 187, 279, 256,
         349,  20, 253,   4,  23,  31,  24, 162, 279, 256, 116, 349, 284, 319,
          13,  17, 287,  11,  32, 159, 236,  54,  97, 183,  16, 255, 142, 153,
         421,  45,  68,  89, 337,  92, 211, 298, 131, 111, 143, 186, 288, 151,
         417,  90, 125, 141, 260, 400, 256, 349,  20,  55, 170, 381,  55, 169,
           3,  40, 250, 381, 404, 371, 254,  95, 353, 263,  60,  41,  18, 198,
         108, 212, 451, 178, 257,  69, 270, 147, 396, 177,  83,  93, 171,  87,
         300, 408,  78,  77, 253,  23,  96, 207, 266, 167, 174, 157, 255,  13,
          17, 287,  32, 159, 236, 248,  54,  47,  16

In [90]:
# Sanity check: no generated id may fall below 3. Ids 0-2 are presumably the
# tokenizer's reserved pieces (SentencePiece defaults: unk=0, bos=1, eos=2);
# confirm against the trained model's options.
assert not (ids_continued < 3).any()

In [91]:
# Decode the sampled token ids (batch dimension squeezed away) back into unit
# characters. The printed string begins with the prompt text and continues
# with the generated part.
unicode_text_continued = sp.DecodeIds(ids_continued.squeeze(0).tolist())
unicode_text_continued

'个俳仕侇仕侇仕侇仕俱侪侾丑伱亚俱侊优俘仔侌仍例么両俜丅侃仫井侦伇乸侸佬乤俔伖乁佐佼丹丝侞俟侏佦丸乔乢丙丠丐侨伾什你亩乱代佀丱你伫侉佂仏伅丿任乜乪丱你伫侉佂伅丿俪任伡九俄企侇仮侳乄乙亰书伞亘侂乺佐侧佞仲丫乫伊俔优伜俘他侍例仂两亖俌俎仸佟丿任乜伤乲俀伤佑伖乇侴俀俯伦佻伽仓乑仺伙俁信丠侜佧仿俓伋乶佫仐之于乖亃侢亓乞伭侤乪你亩俦乾也介一侂俄企侇侳乄乙伉亰伎亘侂侽係中佹乜乪丱你伫侉佂从亶俜乹亾伻代三亊乓乜三亊乓亠伦佻伽仉丬併亯令仔俬亗侖伄俪伙侠乀乧亩丫侔从乽乸乀乧亊伀佂乂人伵仄侯佅伝介一侂企侇佨乄乙亰伎亘侽仚亊乓伹佝佞乨丫伪佑伩乇侴俀丕丱乳侉佂侗侖享乵侮修仛丞仱佾串串保仢代佀丱佥丂侕丣仆临俋两丏佄也介俄仮亏侂乗丌侈丶佤亨互侧佞你侵丳丬伐侔伵仄侯信丠侜令乆丛乊于乖亃侭侃乎乬侦佛佅侍修亱丮仯侂仮侳乄乙亰企侽仚亊乓佹依亍仦佞仲丙丠侜佧他乵俨伷伫侉仏伭侤伤丸丸乔亗侖乿亭以俊俑伡依予乁佼传丝串保侖仭俞似伩佃俲上人住佢乌侖仭俜乲仆临便乚佝仧俠仨亡一伯企侀侇侀侇仅佣仠 个俳书亘伱侂侊侱些供东伽仱么両代乆丛佂侗侖享侌仍侢仳侻佟丿任乜伤俜乹亠侌仍俨伫仏便伥伇乸丠侜令乆丛乊之于乖乾也介九佯亥乗侙义乵串亨互俬侁之于乖亃仱么両代佀你伫侉仏亲伥'

In [92]:
# Invert the unit -> character mapping (code point - 0x4E00) and move the result
# to the GPU. Characters whose offset falls outside [0, n_units) are discarded
# rather than passed to the vocoder, e.g. the space visible in the decoded
# text above.
units_continued = torch.tensor(
    [
        unit_id
        for unit_id in (ord(ch) - 0x4E00 for ch in unicode_text_continued)
        if 0 <= unit_id < n_units
    ],
    dtype=torch.long,
    device="cuda",
)
units_continued

tensor([ 42, 499, 213, 391, 213, 391, 213, 391, 213, 497, 426, 446,  17, 305,
        154, 497, 394, 280, 472, 212, 396, 205, 395,  72,  33, 476,   5, 387,
        235, 149, 422, 263, 120, 440, 364, 100, 468, 278,  65, 336, 380,  57,
         29, 414, 479, 399, 358,  56,  84,  98,  25,  32,  16, 424, 318, 192,
        352, 169, 113, 227, 320,  49, 352, 299, 393, 322, 207, 261,  63, 251,
         92, 106,  49, 352, 299, 393, 322, 261,  63, 490, 251, 289,  93, 452,
        257, 391, 238, 435,  68,  89, 176, 102, 286, 152, 386, 122, 336, 423,
        350, 242,  43, 107, 266, 468, 280, 284, 472, 214, 397, 395, 194,  36,
        150, 460, 462, 248, 351,  63, 251,  92, 292, 114, 448, 292, 337, 278,
         71, 436, 448, 495, 294, 379, 317, 211,  81, 250, 281, 449, 481,  32,
        412, 359, 255, 467, 267, 118, 363, 208,  75, 142,  86, 131, 418, 147,
         94, 301, 420, 106, 352, 169, 486, 126,  95, 203,   0, 386, 452, 257,
        391, 435,  68,  89, 265, 176, 270, 152, 386, 445, 450,  

In [93]:
# Move the GSLM pipeline to the GPU for vocoding; units_continued is a CUDA tensor.
gslm = gslm.to("cuda")

In [94]:
# Vocode the units (prompt followed by continuation) back into a waveform.
# NOTE(review): this rebinds `sr`, which previously held the LibriSpeech sample
# rate from dataset[300]. Re-running the prompt Audio cell after this one will
# use the rate returned by gslm.decode instead.
wav_continued, sr = gslm.decode(units_continued)

In [95]:
# Play the re-synthesised prompt plus continuation. The waveform is copied to
# CPU before being handed to IPython's Audio widget.
Audio(wav_continued.cpu(), rate=sr)