In [1]:
import torch
from supervoice.tokenizer import Tokenizer
from supervoice.model import SupervoiceGPT
from train_config import config

In [2]:
# Model: build the tokenizer and GPT from the training config, then restore the trained weights.
CHECKPOINT_PATH = "./output/gpt_new_tokenizer.pt"
tokenizer = Tokenizer(config, "tokenizer.model")
model = SupervoiceGPT(config)
# NOTE: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source. map_location="cpu" lets this
# run on a machine without a GPU.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(checkpoint['model'])
# Switch to eval mode so the Dropout layers (mlp_output_dropout) are inactive
# during inference; the cell's output is the module repr.
model.eval()

SupervoiceGPT(
  (input_embedding): Embedding(16384, 512)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-7): 8 x AttentionBlock(
        (attention_ln): RMSNorm()
        (attention): Linear(in_features=512, out_features=1536, bias=False)
        (attention_output): Linear(in_features=512, out_features=512, bias=False)
        (mlp_ln): RMSNorm()
        (mlp_input): Linear(in_features=512, out_features=2048, bias=True)
        (mlp_output): Linear(in_features=2048, out_features=512, bias=True)
        (mlp_output_dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (skip_combiners): ModuleList()
    (output_norm): RMSNorm()
  )
  (prediction_head): Linear(in_features=512, out_features=16384, bias=False)
)

In [8]:
# Prepare input
# Text to convert to phonemes. An earlier test sentence and its reference
# phonemes, kept here for comparison only (they are NOT what is encoded below):
#   "He was anxious to make this clear."
#   ç iː  w ə z  æ ɲ c ʃ ə s  t ə  m ej k  d̪ ɪ s  c ʎ ɪ ɹ
# `text` (not `input`) avoids shadowing the Python builtin.
text = "what?".lower()
expected = ""  # reference phonemes for `text`; not compared in the cells shown

# Prompt layout:
#   <sequence_begin> <text_begin> <text tokens...> <text_end> <phonemes_begin>
# The generation cells continue from <phonemes_begin>.
prompt_ids = (
    [tokenizer.sequence_begin_token_id, tokenizer.text_begin_token_id]
    + tokenizer.encode(text)
    + [tokenizer.text_end_token_id, tokenizer.phonemes_begin_token_id]
)
tokenized = torch.tensor(prompt_ids, dtype=torch.long)

# Generation stops at <sequence_end> or after max_generate_len new tokens.
stop_tokens = [tokenizer.sequence_end_token_id]
max_generate_len = 512

In [9]:
# Generate: autoregressively continue the prompt with phoneme tokens.
GENERATE_TOP_K = 6  # number of candidates considered per step (passed as top_k)
SEED = 0
# top_k suggests stochastic sampling — seed so a Restart & Run All reproduces
# the same output. TODO confirm generate() actually samples.
torch.manual_seed(SEED)

prompt = tokenized  # 1-D id tensor from the prepare-input cell
generated = model.generate(
    prompt.unsqueeze(0),  # add a batch dimension of size 1
    max_new_tokens=max_generate_len,
    stop_tokens=stop_tokens,
    top_k=GENERATE_TOP_K,
)
res = generated.squeeze(0).tolist()  # drop the batch dimension -> list of ids
print(res)
# Render the "•" symbol as <SIL> (presumably the silence marker) for readability.
print(tokenizer.decode(res).replace("•", "<SIL>"))

[1, 3, 638, 423, 4, 5, 14, 156, 314, 314, 314, 443, 443, 443, 314, 314, 443, 443, 443, 443, 314, 443, 314, 396, 443, 443, 636, 443, 443, 636, 443, 636, 443, 443, 636, 443, 19, 19, 19, 19, 19, 19, 43, 19, 425, 314, 425, 443, 443, 443, 443, 443, 425, 443, 443, 443, 443, 443, 443, 443, 443, 443, 443, 443, 636, 6, 2]
what?owowowowowowowowɛɛɛɛɛɛɛɛɛɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɹɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɹɹɹɹɹɹɹɹɹɹɹɹɹɹɒɒɒɒɒɒɒɒɒɒɒɹɹɹɹɹɹɹɹɹɹɹɹɹɹɫɫɫɫɫɫɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒdjdjdjdjɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒdjdjdjdjɒɒɒɒɒɒɒɒɒɒɒdjdjdjdjɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒdjdjdjdjɒɒɒɒɒɒɒɒɒɒɒ and and and and and andssssssssssssss andææææææææææææææææɹɹɹɹɹɹɹɹɹɹɹɹɹɹææææææææææææææææɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒææææææææææææææææɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒɒdjdjdjdj


In [5]:
# Inspect the model's top next-token candidates for the prepared prompt.
# Alternative prompt (previously commented-out code): use only the
# phonemes-begin token, i.e. torch.tensor([tokenizer.phonemes_begin_token_id]).
prompt = tokenized
total = 16  # number of top candidates to list
probs, indices = model.predict_next(prompt, total, trim_generated = False)
probs = probs.tolist()
# Convert the index tensor once (not per iteration); zip also tolerates
# receiving fewer than `total` candidates.
for token_id, prob in zip(indices.tolist(), probs):
    piece = tokenizer.decode([token_id])
    print(f"{token_id}: {piece}: {prob}")

tensor(14): owowowowowowowow: 0.819942057132721
tensor(998): found: 0.012457217089831829
tensor(1051): tjtjtjɪɪɪɪ: 0.008038206957280636
tensor(1781): park: 0.006797855719923973
tensor(165): ææææææææ: 0.006542442832142115
tensor(142): this: 0.005550021305680275
tensor(5214): wwwwɐɐɐ: 0.004891290329396725
tensor(3800): ɪɪɪɪɪddd: 0.00474135298281908
tensor(203): pppppppp: 0.0047226701863110065
tensor(1286): band: 0.004326488822698593
tensor(1816): at: 0.004147302359342575
tensor(5095): ðððððððəəə: 0.003880544798448682
tensor(7036): marked: 0.0035794831346720457
tensor(2847): ççç: 0.003374258056282997
tensor(2137): service: 0.0033320633228868246
tensor(8457): ɑɑɑɑɑɑɑɑɑɑʔʔʔʔʔ: 0.0028107769321650267


In [6]:
# Iterative
# Reset the running prompt to the prepared input ids. The following greedy-step
# cell rebinds `input` to a longer tensor on every run, so re-run this cell to
# start over from the original prompt.
input = tokenized

In [7]:
# Greedy step: list the top candidates for the current `input`, then append the
# highest-ranked one (index 0 — the listing is in descending probability order),
# so each re-run of this cell extends the sequence by one token.
# Non-idempotent by design; re-run the `input = tokenized` cell to start over.
total = 16  # number of top candidates to list
probs, indices = model.predict_next(input, total, trim_generated = False)
probs = probs.tolist()
# Convert once instead of calling indices.tolist() inside the loop.
for token_id, prob in zip(indices.tolist(), probs):
    piece = tokenizer.decode([token_id])
    print(f"{token_id}: {piece}: {prob}")
# Append the top-ranked id with torch.cat instead of rebuilding a Python list
# that mixes ints with a 0-d tensor; `.to(input)` matches dtype and device.
input = torch.cat([input, indices[:1].to(input)])

tensor(14): owowowowowowowow: 0.819942057132721
tensor(998): found: 0.012457217089831829
tensor(1051): tjtjtjɪɪɪɪ: 0.008038206957280636
tensor(1781): park: 0.006797855719923973
tensor(165): ææææææææ: 0.006542442832142115
tensor(142): this: 0.005550021305680275
tensor(5214): wwwwɐɐɐ: 0.004891290329396725
tensor(3800): ɪɪɪɪɪddd: 0.00474135298281908
tensor(203): pppppppp: 0.0047226701863110065
tensor(1286): band: 0.004326488822698593
tensor(1816): at: 0.004147302359342575
tensor(5095): ðððððððəəə: 0.003880544798448682
tensor(7036): marked: 0.0035794831346720457
tensor(2847): ççç: 0.003374258056282997
tensor(2137): service: 0.0033320633228868246
tensor(8457): ɑɑɑɑɑɑɑɑɑɑʔʔʔʔʔ: 0.0028107769321650267
