In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

In [5]:
print(type(model))

<class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>


In [13]:
import torch

In [67]:
tokens = list(tokenizer.vocab.keys())

In [124]:
import re

re_hex_color_prefix = re.compile(r"[0-9a-fA-F]{1,6}")

# filter tokens only for tokens that match the above regex
allowed_tokens = [i for i,token in enumerate(tokens) if re_hex_color_prefix.fullmatch(token)]
mask = torch.zeros(len(tokens), dtype=torch.long)
mask[allowed_tokens] = 1

In [109]:
#re_hex_color_prefix.match("00") 
res = re_hex_color_prefix.match("FFAAAA")
res

<re.Match object; span=(0, 6), match='FFAAAA'>

In [132]:
class Prompter:
    def __init__(self, seq):
        self.seq = seq
        tokenizer_result = tokenizer(seq)
        self.input_ids = torch.tensor(tokenizer_result["input_ids"])
        self.attention_mask = torch.tensor(tokenizer_result["attention_mask"])

    def step(self, mask=None):
        res = model.forward(input_ids=self.input_ids, attention_mask=self.attention_mask)
        logits = res.logits[-1]
        # if mask is not None: logits[mask] = 0
        next_token_id = logits[-1].argmax()
        self.input_ids = torch.cat((self.input_ids, torch.tensor([next_token_id]))).to(torch.long)
        self.attention_mask = torch.cat((self.attention_mask, torch.tensor([1])))
        self.seq = tokenizer.decode(self.input_ids.tolist())

    def steps(self, n, until=None, mask=None):
        for i in range(n): 
            self.step(mask=mask)
            if until is not None and self.seq.endswith(until): 
                break

    def result(self):
        return self.seq

    def prompt(self, s):
        self.seq += s
        tokenizer_result = tokenizer(self.seq)
        self.input_ids = torch.tensor(tokenizer_result["input_ids"])
        self.attention_mask = torch.tensor(tokenizer_result["attention_mask"])

p = Prompter("In this tutorial you will learn about hex encoding of colors.")
p.steps(10)
p.prompt(". For example, the hex code for red is #FF0000. For green the hex code is different: ")
p.steps(12, until=".")

print(p.result())

In this tutorial you will learn about hex encoding of colors.!!!!!!!!!!. For example, the hex code for red is #FF0000. For green the hex code is different:!!!!!!!!!!!!
