In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

In [2]:
prompt = 'What is the capital of France? '
max_tokens = 20
input_ids = tokenizer(prompt, return_tensors='pt').input_ids
output = model.generate(input_ids, max_new_tokens=max_tokens)
tokenizer.batch_decode(output, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['What is the capital of France? \xa0The capital of France is the capital of France. \xa0The capital of France is the capital']

In [3]:
processors = []

In [4]:
from axtk.generation_utils.logits_processors.token_healing_logits_processor import TokenHealingLogitsProcessor
healer = TokenHealingLogitsProcessor(input_ids[0], tokenizer)
healed_token_ids = healer.healed_token_ids
if len(healed_token_ids) > 0:
    input_ids = input_ids[:, :-len(healed_token_ids)]
    max_tokens += len(healed_token_ids)
    processors.append(healer)

In [5]:
from axtk.generation_utils import RegexLogitsProcessor
proc = RegexLogitsProcessor(r'Paris|London|Berlin', prefix_length=len(prompt), stop_regex='', tokenizer=tokenizer)
processors.append(proc)

In [6]:
procesors = LogitsProcessorList(processors)
output = model.generate(input_ids, logits_processor=processors, max_new_tokens=max_tokens)
tokenizer.batch_decode(output, skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['What is the capital of France? Paris']

In [None]:
tokenizer('What is the capital of France? Paris').input_ids

In [None]:
tokenizer('What is the capital of France? ').input_ids

In [None]:
tokenizer.decode([6342])