In [10]:
import torch
import transformers
from transformers import MaxLengthCriteria, StoppingCriteriaList

In [82]:
# Identifier handed to `from_pretrained` for both the model and the tokenizer.
# Alternatives left by the author:
#   "hf-internal-testing/tiny-xlm-roberta", "xlm-roberta-base", "gpt2"
model_ref = "GroNLP/gpt2-small-dutch"

In [83]:
# Load the tokenizer and the causal-LM weights for `model_ref`.
# `is_decoder=True` is forwarded to `from_pretrained` as an extra keyword argument.
tokenizer = transformers.AutoTokenizer.from_pretrained(model_ref)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_ref,
    is_decoder=True,
)

In [90]:
# Prompt to be continued. An English alternative left by the author:
#   "One example of a long word is"
input_str = "Ik heb"
# Tokenize into PyTorch tensors and keep only the token-id tensor.
input_ids = tokenizer(input_str, return_tensors="pt")["input_ids"]

In [91]:
# If the first row of the prompt ends with the tokenizer's EOS id, trim the last
# position from every row so the prompt no longer ends with EOS.
if input_ids[0, -1].item() == tokenizer.eos_token_id:
    input_ids = input_ids[..., :-1]

In [92]:
# Beam-search scorer for a single prompt: 10 beams, keeping all 10 hypotheses,
# allocated on the same device as the model.
beam_scorer = transformers.BeamSearchScorer(
    batch_size=1, num_beams=10, num_beam_hyps_to_keep=10, device=model.device
)

In [93]:
# Run beam search by hand: the prompt is repeated once per beam along the first
# dimension, and decoding stops via a MaxLengthCriteria with max_length=6.
# `StoppingCriteriaList` and `MaxLengthCriteria` come from the notebook's import cell.
# NOTE(review): direct calls to `beam_search` appear to be deprecated in newer
# transformers releases in favour of `generate(num_beams=...)` — confirm against
# the installed version.
generated = model.beam_search(
    torch.cat([input_ids] * beam_scorer.num_beams),
    beam_scorer,
    stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=6)]),
)

In [100]:
# Sample 10 continuations of the prompt (sampling combined with 10 beams,
# temperature 5.0, max_length=6). This overwrites the earlier `generated`.
# Seed the torch RNG first so that re-running this cell reproduces the same samples;
# without a seed every run yields different sequences.
torch.manual_seed(0)
generated = model.generate(
    input_ids,
    do_sample=True,
    temperature=5.0,
    max_length=6,
    num_beams=10,
    num_return_sequences=10,
)

In [101]:
# Raw token ids of the generated sequences, one row per returned sequence.
generated

tensor([[  533,   350,   200,   800,   422,  1283],
        [  533,   350,   287,  1480,   228,   200],
        [  533,   350,   363,   607,  4462,   462],
        [  533,   350,   200,   815,  4894,    14],
        [  533,   350,   177,  1308,   244,   360],
        [  533,   350,   423,   759,   386,   195],
        [  533,   350,   233,   228, 12033,   254],
        [  533,   350,   606,   551,   649,  1194],
        [  533,   350,   414,   228,  1510,  1659],
        [  533,   350,   800,   551,   195,   177]])

In [102]:
# Show each generated sequence as its raw token strings joined by spaces
# (token strings are shown as the tokenizer stores them, e.g. with "Ġ" prefixes).
[
    " ".join(token_strings)
    for token_strings in map(tokenizer.convert_ids_to_tokens, generated)
]

['Ik Ġheb Ġhet Ġnooit Ġmeer Ġgezien',
 'Ik Ġheb Ġhaar Ġgezegd Ġdat Ġhet',
 "Ik Ġheb Ġook Ġiets Ġontdekt ,'",
 'Ik Ġheb Ġhet Ġaltijd Ġgeweten ,',
 'Ik Ġheb Ġde Ġauto Ġvoor Ġme',
 'Ik Ġheb Ġdaar Ġheel Ġwat Ġvan',
 'Ik Ġheb Ġal Ġdat Ġgedoe Ġmet',
 'Ik Ġheb Ġhier Ġveel Ġmensen Ġnodig',
 'Ik Ġheb Ġwel Ġdat Ġgevoel Ġgehad',
 'Ik Ġheb Ġnooit Ġveel Ġvan Ġde']

In [48]:
# Decode every generated sequence back into a string.
# NOTE(review): the output recorded below carries execution count 48 (earlier than
# the cells above) and shows an English prompt with <s>/</s> tokens, so it looks
# stale relative to the current model and prompt — re-run this cell to refresh it.
tokenizer.batch_decode(generated)

['<s> One example of a long word is</s></s>',
 '<s> One example of a long word is</s><s>',
 '<s> One example of a long word is</s>.']