In [None]:
from hfppl  import Model, CachedCausalLM, Token, LMContext, smc_standard, TokenCategorical
from string import punctuation
import numpy as np

In [None]:
# Shared language model used by every hfppl model in this notebook.
# NOTE(review): load_in_8bit is presumably forwarded to the Hugging Face loader
# (needs bitsandbytes and a GPU) -- confirm for your environment.
LLM = CachedCausalLM.from_pretrained(
    "lmsys/Vicuna-7b-v1.5",
    load_in_8bit=True,
)

In [None]:
class InfillingModel(Model):
    """Fill each ``[BLANK]`` in a prompt with LM-sampled text, conditioned on what follows.

    The prompt is split on ``"[BLANK]"``. Each ``step`` fills one blank:

    1. It samples a geometrically distributed number of tokens from the LM.
    2. It *observes* the tokens of the fixed prompt segment that follows the blank.

    Particles are therefore weighted by how well their infill leads into the rest
    of the prompt.

    Args:
        LLM: the ``CachedCausalLM`` used for generation and for tokenizing segments.
        prompt: text containing zero or more ``[BLANK]`` markers.
        max_tokens: total budget of sampled (infill) tokens across all blanks.
    """

    def __init__(self, LLM, prompt, max_tokens):
        super().__init__()
        self.parts = prompt.split("[BLANK]")
        self.lm = LMContext(LLM, self.parts[0])
        self.max_tokens = max_tokens
        # Index of the fixed segment that follows the blank being filled next.
        self.current_part_index = 1
        self.generated_text = self.parts[0]  # Initialize with the first part of the prompt
        # Token ids of every fixed segment after a blank. They are observed, not
        # sampled, so the infill is conditioned on them.
        # NOTE(review): segments are tokenized on their own, so token boundaries may
        # differ slightly from tokenizing the whole string. Confirm this is acceptable
        # for your tokenizer.
        self.segment_tokens = [LLM.tokenizer.encode(part, add_special_tokens=False)
                               for part in self.parts[1:]]

    async def step(self):
        """Fill the next blank and condition on the segment after it, or finish."""
        if self.current_part_index >= len(self.parts):
            self.finish()
            print(f"Generated: {self.generated_text}")
            return

        # Number of infill tokens for this blank, capped by the remaining budget.
        n = max(0, min(self.sample_geom(0.5) + 1, self.max_tokens))
        for _ in range(n):
            # Sampling from next_token() advances self.lm by the sampled token.
            # Passing the same context as a proposal, or rebuilding the context
            # afterwards, would add the token a second time.
            token = await self.sample(self.lm.next_token())
            self.generated_text += str(token)
        self.max_tokens -= n

        # Condition on the fixed text after the blank. Each observation scores the
        # token under the LM and advances the context by it (hfppl LMNextToken
        # semantics).
        for token_id in self.segment_tokens[self.current_part_index - 1]:
            await self.observe(self.lm.next_token(), token_id)
        self.generated_text += self.parts[self.current_part_index]
        self.current_part_index += 1

    def sample_geom(self, p):
        """Draw from a geometric distribution on {0, 1, 2, ...} with success probability ``p``."""
        # np.random.geometric has support {1, 2, ...}; shift it to start at 0.
        return np.random.geometric(p) - 1

In [None]:
prompt = """Yesterday I went to the country [BLANK] and bought a [BLANK]"""

# Only the text before the first [BLANK] is a prefix shared by every particle:
# InfillingModel starts its context from prompt.split("[BLANK]")[0]. That prefix is
# the part worth caching. The full prompt, literal "[BLANK]" markers included, is
# never fed to the model.
LLM.cache_kv(LLM.tokenizer.encode(prompt.split("[BLANK]")[0]))

In [None]:
async def run_infilling(n_particles=20, max_tokens=50):
    """Run SMC infilling on the global ``prompt`` and print each particle's text.

    Args:
        n_particles: number of particles passed to ``smc_standard``.
        max_tokens: ``max_tokens`` argument forwarded to ``InfillingModel``.

    Returns:
        The particles returned by ``smc_standard``.
    """
    infilling_model = InfillingModel(LLM=LLM, prompt=prompt, max_tokens=max_tokens)
    particles = await smc_standard(infilling_model, n_particles)
    for particle in particles:
        print(particle.generated_text)
    return particles

In [None]:
await run_infilling()

In [None]:
def _word_length_masks(vocab, eos_token_id):
    """Build the per-word-length token masks consulted by ``ConstraintModel``.

    Returns a dict mapping ``prefix_len`` (0..5) to the set of token ids that may
    come next. ``prefix_len`` is the length of the partial word at the end of the
    text, capped at 5. A token is allowed when all of the following hold:

    * it is not EOS and contains no newline;
    * it contains at least one letter or punctuation character;
    * it is at most five characters once stripped;
    * if it begins with a letter (and so extends the current word), that word
      stays at most five characters long.
    """
    def is_allowed(prefix_len, token_id, text):
        if token_id == eos_token_id or '\n' in text:
            return False
        # Also rejects the empty string, so text[0] below is always safe.
        if not any(ch.isalpha() or ch in punctuation for ch in text):
            return False
        if len(text.strip()) > 5:
            return False
        return not text[0].isalpha() or prefix_len + len(text) <= 5

    return {
        prefix_len: {token_id for token_id, text in enumerate(vocab)
                     if is_allowed(prefix_len, token_id, text)}
        for prefix_len in range(6)
    }


MASKS = _word_length_masks(LLM.vocab, LLM.tokenizer.eos_token_id)

class ConstraintModel(Model):
    """Generate text in which every whitespace-separated word has at most five characters.

    Each token is proposed from a second context, ``self.q``, whose next-token
    distribution is masked to the tokens allowed by ``MASKS``. The proposed token
    is then scored under the unmasked context ``self.lm``, which is passed to
    ``sample`` as the target.

    Args:
        prompt: text both contexts start from.
        max_tokens: number of tokens to generate before finishing.
    """

    def __init__(self, prompt, max_tokens):
        super().__init__()
        self.lm         = LMContext(LLM, prompt)  # target: unconstrained LM
        self.q          = LMContext(LLM, prompt)  # proposal: masked every step
        self.prompt_len = len(str(self.lm.s))     # lets step() print only the continuation
        self.max_tokens = max_tokens

    async def step(self):
        """Propose one constraint-satisfying token, score it, and check for termination."""
        # Which tokens are allowed?
        mask = self.active_constraint_mask()

        # Generate proposed token.
        token = await self.sample(self.lm.next_token(),
                                  proposal = await self.proposal(mask))

        # Condition on constraint — a no-op since proposal already guarantees the constraint
        self.condition(token.token_id in mask)

        # Reduce number of max tokens remaining
        self.max_tokens -= 1

        print(str(self.lm.s)[self.prompt_len:])

        # Check if done. `<= 0` rather than `== 0`, so a non-positive starting budget
        # still terminates. NOTE: the MASKS sets exclude EOS, so the EOS test only
        # matters if the masks change. In practice generation stops at max_tokens.
        if token == LLM.tokenizer.eos_token_id or self.max_tokens <= 0:
            self.finish()

    def active_constraint_mask(self):
        """Return the MASKS entry for the length (capped at 5) of the trailing partial word."""
        string_so_far = str(self.lm.s)
        # If the text ends in whitespace, the last word is already complete and the
        # next token starts a new word, so the partial word is empty.
        if not string_so_far or string_so_far[-1].isspace():
            last_word = ""
        else:
            last_word = string_so_far.split()[-1]
        return MASKS[min(5, len(last_word))]

    async def proposal(self, mask):
        """Mask the proposal context to ``mask`` and return its next-token distribution."""
        # Force the proposal StatefulLM to adhere to this mask
        await self.intervene(self.q.mask_dist(mask), True)

        # Return the proposal's modified next-token distribution
        return self.q.next_token()