In [1]:
import torch
from supervoice.tokenizer import Tokenizer
from supervoice.model import SupervoiceGPT
from train_config import config

In [2]:
# Model
# Build the tokenizer and the network from the shared training config, then
# restore the saved weights onto the CPU.
tokenizer = Tokenizer(config)
model = SupervoiceGPT(config)
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source; weights_only=True may be usable if the
# checkpoint holds nothing but tensors and plain containers — TODO confirm.
checkpoint = torch.load('./output/gpt_first.pt', map_location="cpu")
model.load_state_dict(checkpoint['model'])
# Switch to evaluation mode (the module has a Dropout layer). Kept as the last
# expression so the cell displays the module summary.
model.eval()

SupervoiceGPT(
  (input_embedding): Embedding(916, 512)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-7): 8 x AttentionBlock(
        (attention_ln): RMSNorm()
        (attention): Linear(in_features=512, out_features=1536, bias=False)
        (attention_output): Linear(in_features=512, out_features=512, bias=False)
        (mlp_ln): RMSNorm()
        (mlp_input): Linear(in_features=512, out_features=2048, bias=True)
        (mlp_output): Linear(in_features=2048, out_features=512, bias=True)
        (mlp_output_dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (skip_combiners): ModuleList()
    (output_norm): RMSNorm()
  )
  (prediction_head): Linear(in_features=512, out_features=916, bias=False)
)

In [7]:
# Prepare input
# Prompt layout: [sequence begin, text begin, <character ids>, text end, phonemes begin].
# `text` is used instead of `input` so the builtin is not shadowed by this cell.
text = "It's too close to call.".lower()
# The tokenizer is called on a list of single characters.
char_ids = tokenizer(list(text)).tolist()
prompt_ids = [tokenizer.sequence_begin_token_id, tokenizer.text_begin_token_id] + char_ids + [tokenizer.text_end_token_id, tokenizer.phonemes_begin_token_id]
# `tokenized` is bound once, to the final tensor that later cells consume.
tokenized = torch.tensor(prompt_ids)
# Passed to model.generate as stop_tokens; presumably generation halts once the
# sequence-end token is produced — verify against SupervoiceGPT.generate.
stop_tokens = [tokenizer.sequence_end_token_id]
# Upper bound on newly generated tokens (passed as max_new_tokens).
max_generate_len = 512

In [8]:
# Generate
# `tokenized` is passed directly instead of aliasing it to `input`, which would
# shadow the builtin. unsqueeze(0) adds the batch dimension for the call and
# squeeze(0) removes it from the result.
# Gradients are not needed for generation; no_grad is harmless if
# model.generate already disables them.
with torch.no_grad():
    res = model.generate(tokenized.unsqueeze(0), max_new_tokens = max_generate_len, stop_tokens = stop_tokens).squeeze(0)
# Show the raw token ids, then the tokenizer's reverse mapping of them.
print(res)
print(tokenizer.reverse(res.tolist()))

tensor([  3,   5, 796, 807, 774, 806, 769, 807, 802, 802, 769, 790, 799, 802,
        806, 792, 769, 807, 802, 769, 790, 788, 799, 799, 779,   6,   7,  99,
         99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99,  99, 179,
        179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 744, 744,
        744, 744, 744, 744, 744, 744, 744, 744, 744, 744, 744, 744, 744, 744,
        744, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179, 179,
        179, 179, 179, 179,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
        450, 450, 450, 450, 450, 450, 450, 450, 450, 450, 450, 450, 450, 450,
        450, 450, 450, 450, 450, 450, 450, 450, 450, 450, 450, 4

In [None]:
# Generate
# Runs generation on the prepared prompt again. `tokenized` is passed directly
# instead of aliasing it to `input`, which would shadow the builtin.
# unsqueeze(0) adds the batch dimension for the call and squeeze(0) removes it
# from the result.
# Gradients are not needed for generation; no_grad is harmless if
# model.generate already disables them.
with torch.no_grad():
    res = model.generate(tokenized.unsqueeze(0), max_new_tokens = max_generate_len, stop_tokens = stop_tokens).squeeze(0)
# Show the raw token ids, then the tokenizer's reverse mapping of them.
print(res)
print(tokenizer.reverse(res.tolist()))