In [1]:
import glob

import torch

from src.tokenizer import Tokenizer
from src.dataset import WordsDataset
from src.model import CharLanguageModel

In [4]:
# Rebuild the tokenizer from the training split — presumably the same vocabulary
# the checkpoint was trained with (TODO confirm); ids must line up for the weights to make sense.
data = WordsDataset._read_data('./data/wikitext-2/wiki.train.tokens')
tokenizer = Tokenizer(data)

# NOTE(review): 100 and 50 are positional CharLanguageModel hyper-parameters
# (presumably embedding / hidden sizes — confirm); they must match the
# checkpoint's training configuration for load_state_dict to succeed.
model = CharLanguageModel(100, 50, tokenizer.vocab_size, tokenizer.pad_id)
model.eval()

# glob order is filesystem-dependent: sort so every run picks the same file,
# and fail with a clear message instead of an IndexError when none exist.
ckpt_paths = sorted(glob.glob('./checkpoints/*.ckpt'))
if not ckpt_paths:
    raise FileNotFoundError('no *.ckpt files found in ./checkpoints')
ckpt = ckpt_paths[0]
print(f'Loading checkpoint: {ckpt}')
# torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(ckpt, map_location=torch.device('cpu')))

Build vocabulary
Vocab size: 270


<All keys matched successfully>

In [5]:
def generate(model, start_chars, tokenizer, max_length=20, top_k=100, sep=' '):
    """Sample a continuation of ``start_chars`` from ``model`` using top-k sampling.

    At every step the model is run on the whole sequence so far. The ``top_k``
    highest-scoring candidates for the next position are kept, and one of them
    is drawn with probability ``softmax`` of their scores. Generation stops once
    the sequence holds ``max_length`` ids or the last id is ``tokenizer.eos_id``.

    Args:
        model: callable mapping a ``(1, seq_len)`` long tensor of token ids to a
            tensor indexed as ``[batch, position, vocab]``. Its scores are passed
            through ``softmax``, so they are assumed to be unnormalized logits —
            TODO confirm against ``CharLanguageModel``.
        start_chars: prompt text, encoded with ``tokenizer.encode``.
        tokenizer: object providing ``encode(text)['data']`` (list of ids),
            ``decode(ids)`` (iterable of strings) and ``eos_id``.
        max_length: stop once the id sequence (prompt included) reaches this length.
        top_k: number of highest-scoring candidates to sample from; clamped to
            the size of the model's last output dimension.
        sep: string placed between decoded tokens in the result. Defaults to
            ``' '``, which matches the previous behaviour.

    Returns:
        ``sep.join`` of ``tokenizer.decode`` applied to the full id sequence:
        the prompt, the sampled ids, and the EOS id if one was sampled.

    Raises:
        ValueError: if ``top_k < 1`` or ``start_chars`` encodes to no tokens.
    """
    if top_k < 1:
        raise ValueError(f'top_k must be >= 1, got {top_k}')

    # Copy so the appends below never mutate a list the tokenizer may still hold.
    generated_seq = list(tokenizer.encode(start_chars)['data'])
    if not generated_seq:
        raise ValueError('start_chars must encode to at least one token')
    encoded = torch.tensor([generated_seq], dtype=torch.long)

    while len(generated_seq) < max_length and generated_seq[-1] != tokenizer.eos_id:
        with torch.no_grad():
            next_token_logits = model(encoded)[0, -1, :]

        # torch.topk raises if k exceeds the number of candidates, so clamp it.
        k = min(top_k, next_token_logits.size(-1))
        top_logits, top_inds = next_token_logits.topk(k)
        probs = torch.softmax(top_logits, dim=-1)
        choice = torch.multinomial(probs, 1)   # shape (1,)
        token_ind = top_inds[choice]           # shape (1,)
        encoded = torch.cat([encoded, token_ind.view(1, 1)], dim=-1)
        generated_seq.append(token_ind.item())
    return sep.join(tokenizer.decode(generated_seq))

In [8]:
# Sampling is stochastic (torch.multinomial); seed so Restart & Run All reproduces the output.
torch.manual_seed(0)
generate(model, 'He', tokenizer)

'H e r t r r o m'