In [1]:
import constraintlm as clm
import outlines

import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the Qwen2.5-0.5B checkpoint through constraintlm's TransformersLM wrapper;
# later cells use `qwenllm.tokenizer` and pass `qwenllm` to the samplers.
qwenllm = clm.TransformersLM("Qwen/Qwen2.5-0.5B")

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
Some parameters are on the meta device because they were offloaded to the cpu and disk.


In [3]:
# Second handle on the same checkpoint, loaded through outlines; only its
# `.tokenizer` is used below, to build the regex logits processors.
qwenllm_outlines = outlines.models.transformers("Qwen/Qwen2.5-0.5B")

# Constraint: word length <= N (here N = 5)

In [3]:
# Free-text prompts used for the word-length constraint experiments.
prompts = [
    "In July 1789 the French",
    "The best basketball player of all time is Michael",
    "Ludwig Wittgenstein was a",
]

# Tokenize all prompts together; padding=True pads them to a common length
# so the result can be returned as PyTorch tensors.
batch = qwenllm.tokenizer(
    prompts,
    padding=True,
    return_tensors="pt",
)

In [4]:
# Regex constraint: a whitespace character, then one or more words of 1-5
# alphanumeric characters, each optionally followed by one punctuation mark
# and always followed by a whitespace character.
# The pattern is a raw string so that `\s` reaches the regex engine verbatim
# instead of relying on Python keeping an invalid string escape.
lenword = outlines.processors.RegexLogitsProcessor(
    r"\s([A-Za-z0-9]{1,5}[.!?,]?\s)+",
    qwenllm_outlines.tokenizer,
)

In [5]:
# Multinomial sampling with the word-length regex applied as a logits processor.
cons_multinomial = clm.MultinomialSeqSampler(qwenllm, logits_processor=lenword)
cons_generated_token_ids = cons_multinomial.sample(
    batch.input_ids,
    max_length=10,
    top_k=5,
)

# Append the generated ids to the prompt ids along the last axis and decode
# prompt + continuation together.
cons_full_sequences = torch.cat((batch.input_ids, cons_generated_token_ids), dim=-1)
cons_decoded = qwenllm.tokenizer.batch_decode(cons_full_sequences)
print(cons_decoded)

1
2
3
4
5
6
7
8
9
['In July 1789 the French army, under the lead of the Duke of O', 'The best basketball player of all time is Michael James, and this is his 40th', 'Ludwig Wittgenstein was a<|endoftext|> 20th c. Swiss lingu, logic']


In [7]:
# Sampling configuration for this experiment.
num_particles = 5
max_length = 20
B = len(prompts)

# NOTE(review): ess_threshold=10 is larger than num_particles=5 and the recorded
# output shows "Resampling..." at every step — confirm the threshold's units.
smc_sampler_fsm = clm.SMCSampler(qwenllm, lenword)
smc_generated_token_ids_fsm = smc_sampler_fsm.sample(
    batch.input_ids,
    max_length=max_length,
    num_particles=num_particles,
    ess_threshold=10,
)

# Repeat each prompt once per particle, and reshape the first element returned
# by sample() to one row per (prompt, particle) pair before decoding.
prompt_texts = qwenllm.tokenizer.batch_decode(
    batch.input_ids.repeat_interleave(num_particles, dim=0)
)
particle_texts = qwenllm.tokenizer.batch_decode(
    smc_generated_token_ids_fsm[0].reshape(B * num_particles, max_length)
)
for prompt_text, particle_text in zip(prompt_texts, particle_texts):
    print(prompt_text + particle_text)

Resampling...
1
Resampling...
2
Resampling...
3
Resampling...
4
Resampling...
5
Resampling...
6
Resampling...
7
Resampling...
8
Resampling...
9
Resampling...
10
Resampling...
11
Resampling...
12
Resampling...
13
Resampling...
14
Resampling...
15
Resampling...
16
Resampling...
17
Resampling...
18
Resampling...
19
Resampling...
In July 1789 the French first deput for the new era of the City of Paris, Louis XVI.
He thrud to
In July 1789 the French first deput for the new era of the City of Paris, Louis XVI.
He thrud to
In July 1789 the French first deput for the city of Paris and
the first hear of the Bill of Enact in
In July 1789 the French first deput for the city of Paris and
the first hear of the Bill of the Third of
In July 1789 the French first deput for the city of Paris and
the first hear of the Bill of the Third of
The best basketball player of all time is Michael Calab Graza
Diego Gatt
Send an email to AZ Eyes
Last week saw
The best basketball player of all time is Michael Calab

# RPN Typed

In [8]:
import re
from typing import List, Optional, Tuple

# One RPN symbol: an integer literal, an operator (arithmetic or the assignment
# operator "="), or a variable name.
_number_or_op_or_var = re.compile(r"\d+|[+\-*/=]|[a-zA-Z][a-zA-Z0-9]*")
_var_name = re.compile(r'^[A-Za-z][A-Za-z0-9]*$')

class RPNTypedLogitsProcessor(outlines.processors.base_logits_processor.OutlinesLogitsProcessor):
    """Logits processor for RPN expressions with integers, variables and `=`.

    Accepted symbols are integer literals, variable names, the arithmetic
    operators `+ - * /` and the assignment operator `=`. In `name value =`
    the value is bound to `name` and pushed back on the stack. An arithmetic
    operator may only consume numbers or variables that are already defined.

    Masking is done in two stages: a regex FSM restricts the character set,
    then every remaining candidate token is scored by simulating the RPN stack.
    """

    def __init__(self, tokenizer):
        """Store the decoding tokenizer and build the character-level regex FSM.

        Args:
            tokenizer: tokenizer exposing `batch_decode` and `eos_token_id`,
                used to decode candidate sequences.
        """
        self.tokenizer = tokenizer
        # The FSM must allow "=" as well, otherwise assignments can never be generated.
        self.fsm_digitsvar = outlines.processors.RegexLogitsProcessor(r"(\d+|[ +\-*/=]|[a-zA-Z][a-zA-Z0-9]*)+", qwenllm_outlines.tokenizer)

    def _extract_symbols(self, text: str) -> Optional[List[str]]:
        """
        Parse `text` into a list of RPN symbols (integers, variable names,
        and the operators +, -, *, /, =).
        Returns None if any character other than a symbol or a space is present.
        """
        symbols = [m.group(0) for m in _number_or_op_or_var.finditer(text)]
        # After removing every symbol only spaces may remain.
        leftover = _number_or_op_or_var.sub("", text)
        if leftover.replace(" ", ""):
            return None
        return symbols

    @staticmethod
    def _operand_value(operand, variables):
        """Return the number held by a stack entry, or None for an undefined variable."""
        if isinstance(operand, float):
            return operand
        return variables.get(operand)

    def _stack_depth(self, text: str) -> Optional[int]:
        """Simulate the RPN stack for `text`.

        Returns the final stack depth, or None when the text is invalid:
        illegal characters, an operator with fewer than two operands, an
        assignment whose target is not a variable name, use of an undefined
        variable, or a division by zero.
        """
        symbols = self._extract_symbols(text)
        if symbols is None:
            return None

        # Fresh state per text so sequences in a batch cannot influence each other.
        stack = []
        variables = {}
        for sym in symbols:
            if sym.isdigit():
                stack.append(float(sym))
                continue
            if _var_name.match(sym):
                stack.append(sym)
                continue

            # Every operator (including "=") consumes two entries and pushes one.
            if len(stack) < 2:
                return None
            right = stack.pop()
            left = stack.pop()

            if sym == "=":
                # `name value =`: the target must be a variable name and the
                # value must be a number or an already defined variable.
                if not isinstance(left, str):
                    return None
                value = self._operand_value(right, variables)
                if value is None:
                    return None
                variables[left] = value
                stack.append(value)
            else:
                a = self._operand_value(left, variables)
                b = self._operand_value(right, variables)
                if a is None or b is None:      # undefined variable
                    return None
                try:
                    stack.append(apply_op(a, b, sym))
                except ZeroDivisionError:
                    # An expression that cannot be evaluated violates the constraint.
                    return None
        return len(stack)

    def _decode_without_trailing_eos(self, input_ids: torch.LongTensor) -> List[str]:
        """Decode each row after dropping its trailing EOS tokens.

        Without this, the EOS token's text would be part of the decoded string
        and would be rejected as illegal characters.
        """
        eos_id = self.tokenizer.eos_token_id
        rows = []
        for row in input_ids.tolist():
            while row and row[-1] == eos_id:
                row.pop()
            rows.append(row)
        return self.tokenizer.batch_decode(rows, clean_up_tokenization_spaces=False)

    def prefix(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        Given the token IDs generated so far (input_ids),
        return the score associated by the constraint: 0.0 for a valid prefix
        (empty text, or a valid symbol sequence leaving at least one value on
        the stack) and -inf otherwise. One score per row.
        """
        texts = self.tokenizer.batch_decode(
            input_ids, clean_up_tokenization_spaces=False
        )
        scores = []
        for text in texts:
            if not text:
                scores.append(0.0)      # empty context is always a valid prefix
                continue
            depth = self._stack_depth(text)
            scores.append(0.0 if depth is not None and depth >= 1 else float('-inf'))
        return torch.tensor(scores, dtype=torch.float)

    def complete(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        Given the token IDs of a complete (EOS-terminated) sequence (input_ids),
        return the score associated by the constraint: 0.0 when the expression
        is valid and leaves exactly one value on the stack, -inf otherwise.
        One score per row.
        """
        texts = self._decode_without_trailing_eos(input_ids)
        scores = []
        for text in texts:
            depth = self._stack_depth(text)
            scores.append(0.0 if depth == 1 else float('-inf'))
        return torch.tensor(scores, dtype=torch.float)

    def score(self, batch_ids: torch.LongTensor) -> torch.Tensor:
        """Score each row with `complete` if it ends with EOS, else with `prefix`.

        Args:
            batch_ids: tensor of shape (B, L) of token ids.

        Returns:
            Tensor of shape (B,) on `batch_ids.device` holding 0.0 or -inf.
        """
        B, _ = batch_ids.shape

        # Boolean mask of which sequences end in EOS
        is_eos = batch_ids[:, -1] == self.tokenizer.eos_token_id  # (B,)
        idx_eos = torch.nonzero(is_eos, as_tuple=True)[0]
        idx_pref = torch.nonzero(~is_eos, as_tuple=True)[0]

        out = torch.empty(B, device=batch_ids.device, dtype=torch.get_default_dtype())

        # prefix()/complete() build their result on the CPU; move it to `out`'s
        # device and dtype before the indexed assignment.
        if idx_pref.numel() > 0:
            pref_scores = self.prefix(batch_ids[idx_pref])      # (Np,)
            out[idx_pref] = pref_scores.to(device=out.device, dtype=out.dtype)

        if idx_eos.numel() > 0:
            eos_scores = self.complete(batch_ids[idx_eos])      # (Ne,)
            out[idx_eos] = eos_scores.to(device=out.device, dtype=out.dtype)

        return out

    def process_logits(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.Tensor:
        """Mask `logits` so that only tokens keeping the RPN expression valid survive.

        The regex FSM is applied first; each token it leaves unmasked is then
        appended to its sequence and scored with `score`. Tokens scored -inf
        are masked.

        Returns:
            Tensor of shape (B, V): the FSM-processed logits plus a 0/-inf mask.
        """
        # first apply the FSM-based masking
        process_logits1 = self.fsm_digitsvar.process_logits(input_ids, logits)

        device = logits.device

        # ensure batch dimension
        if process_logits1.dim() == 1:
            process_logits1 = process_logits1.unsqueeze(0)
        B, V = process_logits1.shape

        if input_ids is None or input_ids.numel() == 0:
            input_ids = torch.empty((B, 0), dtype=torch.long, device=device)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        mask = torch.zeros_like(process_logits1, dtype=torch.float32, device=device)

        for b in range(B):
            # pick all tokens whose logit != -inf (i.e. not already fully masked)
            candidate_tokens = torch.nonzero(
                process_logits1[b] != float("-inf"),
                as_tuple=True
            )[0]
            if candidate_tokens.numel() == 0:
                print(f"for batch {b}, all tokens are already masked")
                continue

            # build (prefix + each candidate) batches
            seq = input_ids[b]                                          # (seqlen,)
            reps = candidate_tokens.size(0)                             # number of candidates
            seqs = seq.unsqueeze(0).repeat(reps, 1)                     # (reps, seqlen)
            next_ids = candidate_tokens.unsqueeze(1).to(seq.device)     # (reps, 1)
            batch = torch.cat([seqs, next_ids], dim=1)                  # (reps, seqlen+1)

            # score each continuation: 0 or -inf
            scores = self.score(batch)                                  # (reps,)

            # place them into mask
            mask[b, candidate_tokens] = scores.to(device=mask.device, dtype=mask.dtype)

        # add the mask so that disallowed tokens go to -inf
        processed_logits = process_logits1 + mask     # (B, V)
        return processed_logits


def apply_op(a: float, b: float, op: str) -> float:
    """Apply the binary arithmetic operator `op` to `a` and `b`.

    Args:
        a: left operand.
        b: right operand.
        op: one of '+', '-', '*', '/'.

    Returns:
        The result of `a op b`.

    Raises:
        ValueError: if `op` is not one of the four supported operators.
    """
    operations = (
        ('+', lambda left, right: left + right),
        ('-', lambda left, right: left - right),
        ('*', lambda left, right: left * right),
        ('/', lambda left, right: left / right),
    )
    for symbol, operation in operations:
        if op == symbol:
            return operation(a, b)
    raise ValueError(f"Unsupported operator {op!r}")

In [9]:
# Few-shot prompts for the typed-RPN task (infix with one variable assignment -> RPN).
# The query expressions use the ASCII hyphen-minus "-" like the worked examples
# and like the operators accepted by the RPN grammar (previously the second
# query contained the Unicode minus sign U+2212).
prompts_rpntyped = [
    "Example 1:\nInput: foo = 4, (3 + foo) * 5\nOutput: 3 foo 4 = + 5 *\n\nExample 2:\nInput: bar = 3, 7 - (2 + bar) * 4\nOutput: 7 2 bar 3 = + 4 * -\n\nExample 3:\nInput: foofoo = 2, 8 + (foofoo * (4 - 1))\nOutput:",
    "Example 1:\nInput: (3 + 4) * 5\nOutput: 3 4 + 5 *\n\nExample 2:\nInput: bar = 3,  7 - (2 + bar) * 4\nOutput: 7 2 bar 3 = + 4 * -\n\nExample 3:\nInput: foofoo = 5, (3 + 4) * foofoo - 6 \nOutput:",
]
batch_rpntyped = qwenllm.tokenizer(prompts_rpntyped, padding=True, return_tensors="pt")

In [10]:
# Typed-RPN constraint; it decodes candidate sequences with qwenllm's tokenizer.
rpntyped_c = RPNTypedLogitsProcessor(qwenllm.tokenizer)

In [11]:
# Sampling configuration for the typed-RPN experiment.
num_particles = 3
max_length = 10
B = len(prompts_rpntyped)

smc_sampler_rpntyped = clm.SMCSampler(qwenllm, rpntyped_c)
smc_generated_token_ids_rpntyped = smc_sampler_rpntyped.sample(
    batch_rpntyped.input_ids,
    max_length=max_length,
    num_particles=num_particles,
    ess_threshold=10,
)

# Repeat each prompt once per particle, and reshape the first element returned
# by sample() to one row per (prompt, particle) pair before decoding.
prompt_texts = qwenllm.tokenizer.batch_decode(
    batch_rpntyped.input_ids.repeat_interleave(num_particles, dim=0)
)
particle_texts = qwenllm.tokenizer.batch_decode(
    smc_generated_token_ids_rpntyped[0].reshape(B * num_particles, max_length)
)
for prompt_text, particle_text in zip(prompt_texts, particle_texts):
    print("--------------- \n", prompt_text + particle_text)

Resampling...
1
Resampling...
2
Resampling...
3
Resampling...
4
Resampling...
5
Resampling...
6
Resampling...
7
Resampling...
8
Resampling...
9
Resampling...
--------------- 
 Example 1:
Input: foo = 4, (3 + foo) * 5
Output: 3 foo 4 = + 5 *

Example 2:
Input: bar = 3, 7 - (2 + bar) * 4
Output: 7 2 bar 3 = + 4 * -

Example 3:
Input: foofoo = 2, 8 + (foofoo * (4 - 1))
Output: Henceforth the enemy is unable to relate to a
--------------- 
 Example 1:
Input: foo = 4, (3 + foo) * 5
Output: 3 foo 4 = + 5 *

Example 2:
Input: bar = 3, 7 - (2 + bar) * 4
Output: 7 2 bar 3 = + 4 * -

Example 3:
Input: foofoo = 2, 8 + (foofoo * (4 - 1))
Output: Henceforth the enemy is unable to find any weapon
--------------- 
 Example 1:
Input: foo = 4, (3 + foo) * 5
Output: 3 foo 4 = + 5 *

Example 2:
Input: bar = 3, 7 - (2 + bar) * 4
Output: 7 2 bar 3 = + 4 * -

Example 3:
Input: foofoo = 2, 8 + (foofoo * (4 - 1))
Output: Henceforth the enemy is unable to find any door
--------------- 
 Example 1:
Input: (3 + 

# RPN

In [12]:
# One RPN symbol: an integer literal or one of the operators + - * /.
_number_or_op = re.compile(r"\d+|[+\-*/]")

class RPNLogitsProcessor(outlines.processors.base_logits_processor.OutlinesLogitsProcessor):
    """Logits processor that only lets through valid Reverse Polish Notation.

    Masking is done in two stages: a regex FSM restricts the character set to
    digits, spaces and the four operators, then every remaining candidate token
    is scored by tracking the RPN stack depth.
    """

    def __init__(self, tokenizer):
        """Store the decoding tokenizer and build the character-level regex FSM.

        Args:
            tokenizer: tokenizer exposing `batch_decode` and `eos_token_id`,
                used to decode candidate sequences.
        """
        self.tokenizer = tokenizer
        self.fsm_digits = outlines.processors.RegexLogitsProcessor(r"(\d+|[ +\-*/])+", qwenllm_outlines.tokenizer)

    def _extract_symbols(self, text: str) -> Optional[List[str]]:
        """
        Parse `text` into a list of RPN tokens (integers and +, -, *, /).
        Returns None if any invalid characters are present.
        """
        symbols = [m.group(0) for m in _number_or_op.finditer(text)]
        # After removing every symbol only spaces may remain.
        leftover = _number_or_op.sub("", text)
        if leftover.replace(" ", ""):
            return None
        return symbols

    def _stack_depth(self, text: str) -> Optional[int]:
        """Return the RPN stack depth after reading `text`.

        Returns None when the text contains illegal characters or when an
        operator is applied with fewer than two operands on the stack.
        """
        symbols = self._extract_symbols(text)
        if symbols is None:
            return None
        depth = 0
        for sym in symbols:
            if sym.isdigit():
                depth += 1          # push
            else:
                if depth < 2:       # operator needs two operands
                    return None
                depth -= 1          # pop two, push one
        return depth

    def _decode_without_trailing_eos(self, input_ids: torch.LongTensor) -> List[str]:
        """Decode each row after dropping its trailing EOS tokens.

        Without this, the EOS token's text would be part of the decoded string
        and would be rejected as illegal characters.
        """
        eos_id = self.tokenizer.eos_token_id
        rows = []
        for row in input_ids.tolist():
            while row and row[-1] == eos_id:
                row.pop()
            rows.append(row)
        return self.tokenizer.batch_decode(rows, clean_up_tokenization_spaces=False)

    def prefix(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        Given the token IDs generated so far (input_ids),
        return the score associated by the constraint: 0.0 for a valid prefix
        (empty text, or a valid symbol sequence with stack depth >= 1) and
        -inf otherwise. One score per row.
        """
        texts = self.tokenizer.batch_decode(
            input_ids, clean_up_tokenization_spaces=False
        )
        scores = []
        for text in texts:
            if not text:
                scores.append(0.0)      # empty context is always a valid prefix
                continue
            depth = self._stack_depth(text)
            scores.append(0.0 if depth is not None and depth >= 1 else float('-inf'))
        return torch.tensor(scores, dtype=torch.float)

    def complete(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        Given the token IDs of a complete (EOS-terminated) sequence (input_ids),
        return the score associated by the constraint: 0.0 when the expression
        is valid and its final stack depth is exactly 1, -inf otherwise.
        One score per row.
        """
        texts = self._decode_without_trailing_eos(input_ids)
        scores = []
        for text in texts:
            depth = self._stack_depth(text)
            scores.append(0.0 if depth == 1 else float('-inf'))
        return torch.tensor(scores, dtype=torch.float)

    def score(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        For each sequence in `input_ids`, apply `complete` if it ends with EOS, else `prefix`.
        Rows are split into the two groups by index and each group is scored in one call.

        Args:
            input_ids: tensor of shape (B, L) of token ids.

        Returns:
            Tensor of shape (B,) on `input_ids.device` holding 0.0 or -inf.
        """
        B, _ = input_ids.shape

        # Boolean mask of which sequences end in EOS
        is_eos = input_ids[:, -1] == self.tokenizer.eos_token_id  # (B,)
        idx_eos = torch.nonzero(is_eos, as_tuple=True)[0]
        idx_pref = torch.nonzero(~is_eos, as_tuple=True)[0]

        out = torch.empty(B, device=input_ids.device, dtype=torch.get_default_dtype())

        # prefix()/complete() build their result on the CPU; move it to `out`'s
        # device and dtype before the indexed assignment.
        if idx_pref.numel() > 0:
            pref_scores = self.prefix(input_ids[idx_pref])      # (Np,)
            out[idx_pref] = pref_scores.to(device=out.device, dtype=out.dtype)

        if idx_eos.numel() > 0:
            eos_scores = self.complete(input_ids[idx_eos])      # (Ne,)
            out[idx_eos] = eos_scores.to(device=out.device, dtype=out.dtype)

        return out  # shape: (B,)

    def process_logits(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.Tensor:
        """Mask `logits` so that only tokens keeping the RPN expression valid survive.

        The regex FSM is applied first; each token it leaves unmasked is then
        appended to its sequence and scored with `score`. Tokens scored -inf
        are masked.

        Returns:
            Tensor of shape (B, V): the FSM-processed logits plus a 0/-inf mask.
        """
        # First apply the regex-based FSM masking
        logits_after_fsm = self.fsm_digits.process_logits(input_ids, logits)

        device = logits.device

        # ensure batch dim
        if logits_after_fsm.dim() == 1:
            logits_after_fsm = logits_after_fsm.unsqueeze(0)
        B, V = logits_after_fsm.shape

        if input_ids is None or input_ids.numel() == 0:
            input_ids = torch.empty((B, 0), dtype=torch.long, device=device)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        # build mask of shape (B, V)
        mask = torch.zeros_like(logits_after_fsm, dtype=torch.float32, device=device)

        for b in range(B):
            # tokens not already fully masked
            candidates = torch.nonzero(
                logits_after_fsm[b] != float("-inf"),
                as_tuple=True
            )[0]
            if candidates.numel() == 0:
                continue

            # build batch of prefix + each candidate token
            prefix = input_ids[b]                                       # (L,)
            reps = candidates.size(0)                                   # R
            prefixes = prefix.unsqueeze(0).repeat(reps, 1)              # (R, L)
            next_ids = candidates.unsqueeze(1).to(prefix.device)        # (R, 1)
            batch = torch.cat([prefixes, next_ids], dim=1)              # (R, L+1)

            # compute 0 or -inf per continuation
            sc = self.score(batch)                                      # (R,)

            # fill mask
            mask[b, candidates] = sc.to(device=mask.device, dtype=mask.dtype)

        # apply mask to logits
        return logits_after_fsm + mask
    

In [None]:
    # def score(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
    #     """
    #     Score every sequence in `input_ids` (shape: B × L).

    #     • If the last token *is* EOS → treat sequence as **complete**  
    #       valid IFF final stack depth == 1

    #     • Otherwise                       → treat sequence as **prefix**  
    #       valid IFF final stack depth ≥ 1 (empty prefix is allowed)

    #     Returns a tensor of shape (B,) filled with `0.0` for valid and
    #     `-inf` for invalid sequences.
    #     """
    #     device = input_ids.device
    #     batch  = input_ids.size(0)

    #     # 1. Which rows are complete?
    #     is_complete = input_ids[:, -1] == self.tokenizer.eos_token_id   # (B,)

    #     # 2. Decode once for the whole batch
    #     texts = self.tokenizer.batch_decode(
    #         input_ids, clean_up_tokenization_spaces=False
    #     )

    #     out = torch.empty(batch, dtype=torch.float, device=device)

    #     # 3. Per-sequence check (O(#tokens) but only a single small Python loop)
    #     for i, text in enumerate(texts):

    #         # ------------------------------------ fast-path: empty *prefix*
    #         if not is_complete[i] and text == "":
    #             out[i] = 0.0
    #             continue

    #         symbols = self._extract_symbols(text)
    #         if symbols is None:                    # illegal character
    #             out[i] = float("-inf")
    #             continue

    #         depth, valid = 0, True
    #         for sym in symbols:
    #             if sym.isdigit():                  # push
    #                 depth += 1
    #             else:                              # operator: needs ≥2 operands
    #                 if depth < 2:
    #                     valid = False
    #                     break
    #                 depth -= 1                     # pop two, push one

    #         if not valid:
    #             out[i] = float("-inf")
    #             continue

    #         # ------------------------------------ final stack check
    #         if is_complete[i]:                     # complete expression
    #             out[i] = 0.0 if depth == 1 else float("-inf")
    #         else:                                  # still generating (prefix)
    #             out[i] = 0.0 if depth >= 1 else float("-inf")

    #     return out

In [13]:
# Few-shot prompts for infix -> RPN translation. All prompts share the same
# instruction and two worked examples and differ only in the final query.
rpn_few_shot = (
    "Given an arithmetic expression in standard infix notation, it is possible to convert it to Reverse Polish Notation (RPN).\n\n"
    "Example 1:\nInput: (3 + 4) * 5\nOutput: 3 4 + 5 *\n\n"
    "Example 2:\nInput: 7 - (2 + 3) * 4\nOutput: 7 2 3 + 4 * -\n\n"
)
# Queries use the ASCII hyphen-minus "-" like the worked examples and like the
# operators accepted by the RPN grammar (previously two queries contained the
# Unicode minus sign U+2212).
rpn_queries = [
    "(8 / 2) + (3 * (4 - 1))",
    "(3 + 4) * 5 - 6 / (1 + 2)",
    "(6 + 2) * 3 - 4",
]
prompts_rpn = [
    rpn_few_shot + "Example 3:\nInput: " + query + "\nOutput:"
    for query in rpn_queries
]
batch_rpn = qwenllm.tokenizer(prompts_rpn, padding=True, return_tensors="pt")

In [14]:
# Plain-RPN constraint; it decodes candidate sequences with qwenllm's tokenizer.
rpn_c = RPNLogitsProcessor(qwenllm.tokenizer)

In [15]:
# Sampling configuration for the plain-RPN experiment.
num_particles = 3
max_length = 10
B = len(prompts_rpn)

smc_sampler_rpn = clm.SMCSampler(qwenllm, rpn_c)
smc_generated_token_ids_rpn = smc_sampler_rpn.sample(
    batch_rpn.input_ids,
    max_length=max_length,
    num_particles=num_particles,
    ess_threshold=10,
)

# Repeat each prompt once per particle, and reshape the first element returned
# by sample() to one row per (prompt, particle) pair before decoding.
prompt_texts = qwenllm.tokenizer.batch_decode(
    batch_rpn.input_ids.repeat_interleave(num_particles, dim=0)
)
particle_texts = qwenllm.tokenizer.batch_decode(
    smc_generated_token_ids_rpn[0].reshape(B * num_particles, max_length)
)
for prompt_text, particle_text in zip(prompt_texts, particle_texts):
    print("--------------- \n", prompt_text + particle_text)

Resampling...
1
Resampling...
2
Resampling...
3
Resampling...
4
Resampling...
5
Resampling...
6
Resampling...
7
Resampling...
8
Resampling...
9
Resampling...
--------------- 
 Given an arithmetic expression in standard infix notation, it is possible to convert it to Reverse Polish Notation (RPN).

Example 1:
Input: (3 + 4) * 5
Output: 3 4 + 5 *

Example 2:
Input: 7 - (2 + 3) * 4
Output: 7 2 3 + 4 * -

Example 3:
Input: (8 / 2) + (3 * (4 - 1))
Output:<|endoftext|><|endoftext|><|endoftext|>2 4 - 1 + 3 *
--------------- 
 Given an arithmetic expression in standard infix notation, it is possible to convert it to Reverse Polish Notation (RPN).

Example 1:
Input: (3 + 4) * 5
Output: 3 4 + 5 *

Example 2:
Input: 7 - (2 + 3) * 4
Output: 7 2 3 + 4 * -

Example 3:
Input: (8 / 2) + (3 * (4 - 1))
Output:<|endoftext|><|endoftext|><|endoftext|>2 4 - 1 + 3 *
--------------- 
 Given an arithmetic expression in standard infix notation, it is possible to convert it to Reverse Polish Notation (RPN).

Exa

## Finetuning

Build dataset: infix <=> RPN

In [29]:
import random
import re

# Binary operators that may appear in generated expressions.
OPERATORS = ['+', '-', '*', '/']
# Largest integer literal that can be sampled.
MAX_NUMBER = 99

# Candidate literals 1..MAX_NUMBER and their sampling weights: the weight of n
# is 1/n, so small numbers are drawn more often than large ones.
NUMBER_RANGE = list(range(1, MAX_NUMBER + 1))
WEIGHTS = [1.0 / n for n in NUMBER_RANGE]

def sample_number():
    """Draw one integer from NUMBER_RANGE using WEIGHTS and return it as a string."""
    (drawn,) = random.choices(NUMBER_RANGE, weights=WEIGHTS, k=1)
    return str(drawn)

def gen_expr(depth=0, max_depth=3, parent_op=None):
    """Recursively build a random infix expression string.

    Args:
        depth: current recursion depth.
        max_depth: depth at which a number is always emitted.
        parent_op: operator chosen by the caller; when it is '/', this level
            may not pick '/' again.

    Returns:
        A string made of numbers, operators and optional parentheses, with no
        spaces.
    """
    # Leaf: forced at max depth, otherwise taken with probability 0.3.
    is_leaf = depth >= max_depth or random.random() < 0.3
    if is_leaf:
        return sample_number()

    # Operators allowed at this level ('/' is excluded directly under '/').
    allowed_ops = [sym for sym in OPERATORS if not (parent_op == '/' and sym == '/')]
    op = random.choice(allowed_ops)

    # Left operand is generated before the right one.
    left, right = (gen_expr(depth + 1, max_depth, parent_op=op) for _ in range(2))
    combined = f"{left}{op}{right}"

    # Wrap in parentheses half of the time.
    return f"({combined})" if random.random() < 0.5 else combined

def tokenize(expr):
    """Split an infix expression into number, operator and parenthesis tokens.

    Characters that are not part of any token (e.g. spaces) are skipped.
    """
    return [found.group(0) for found in re.finditer(r'\d+|[()+\-*/]', expr)]

def infix_to_postfix(tokens):
    """Convert a list of infix tokens to postfix (RPN) order.

    Uses an operator stack: '*' and '/' bind tighter than '+' and '-', and
    operators of equal precedence are treated as left-associative.

    Args:
        tokens: iterable of strings — integer literals, '+', '-', '*', '/',
            '(' and ')'.

    Returns:
        List of tokens in postfix order (no parentheses).

    Raises:
        ValueError: on unbalanced parentheses or an unsupported token.
    """
    prec = {'+': 1, '-': 1, '*': 2, '/': 2}
    out, stack = [], []
    for t in tokens:
        if re.fullmatch(r'\d+', t):
            out.append(t)
        elif t == '(':
            stack.append(t)
        elif t == ')':
            # Flush operators back to the matching '('.
            while stack and stack[-1] != '(':
                out.append(stack.pop())
            if not stack:
                raise ValueError("unbalanced parentheses: unmatched ')'")
            stack.pop()     # discard the '('
        elif t in prec:
            # Pop operators of higher or equal precedence (left-associativity).
            while stack and stack[-1] != '(' and prec[t] <= prec[stack[-1]]:
                out.append(stack.pop())
            stack.append(t)
        else:
            raise ValueError(f"unsupported token {t!r}")
    # Whatever is left must be operators; a leftover '(' was never closed.
    while stack:
        top = stack.pop()
        if top == '(':
            raise ValueError("unbalanced parentheses: unmatched '('")
        out.append(top)
    return out

def build_dataset(N, max_depth=3):
    """Generate N (infix, postfix) string pairs.

    Expressions are drawn with gen_expr; draws that contain no operator
    (a bare number) are discarded and redrawn. The postfix string is the
    space-joined output of infix_to_postfix.
    """
    pairs = []
    while len(pairs) < N:
        candidate = gen_expr(max_depth=max_depth)
        if re.search(r'[+\-*/]', candidate) is not None:
            rpn_tokens = infix_to_postfix(tokenize(candidate))
            pairs.append((candidate, ' '.join(rpn_tokens)))
    return pairs

# Quick sanity check: print a handful of generated pairs.
samples = build_dataset(5, max_depth=3)
for infix_expr, rpn_expr in samples:
    print(f"{infix_expr}  ->  {rpn_expr}")


(8*(8/5*1*23))  ->  8 8 5 / 1 * 23 * *
4+2+54+2/41  ->  4 2 + 54 + 2 41 / +
(2-2/2+2)  ->  2 2 2 / - 2 +
(44+((14*29)*(9/1)))  ->  44 14 29 * 9 1 / * +
((6-73)-(48-(16-46)))  ->  6 73 - 48 16 46 - - -


Data creation

In [None]:
from datasets import Dataset

# Training corpus: 10000 pairs from shallower expressions (max_depth=2).
samples = build_dataset(10000, max_depth=2)

# One record per pair, keyed by column name.
data_dicts = [dict(infix=infix_expr, postfix=rpn_expr) for infix_expr, rpn_expr in samples]

ds = Dataset.from_list(data_dicts)

# Build the prompt/completion strings for one example
def make_io(example):
    """Turn an {"infix", "postfix"} record into a {"prompt", "completion"} record.

    The prompt ends with "RPN:" and the completion starts with a space, so the
    two concatenate to "...RPN: <postfix>".
    """
    prompt_text = "".join(("Translate infix to RPN:\n", example["infix"], "\nRPN:"))
    completion_text = "".join((" ", example["postfix"]))
    return {"prompt": prompt_text, "completion": completion_text}

# Replace the raw infix/postfix columns with the prompt/completion columns.
ds = ds.map(make_io, remove_columns=["infix","postfix"])

Tokenizer

In [None]:
from transformers import AutoTokenizer

# Base checkpoint to fine-tune; the tokenizer is loaded from the same name.
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer  = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Make sure the tokenizer has a pad token (needed for padding="max_length" below).
# NOTE(review): adding "[PAD]" grows the vocabulary; confirm the model's
# embeddings are resized if this branch is ever taken.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})


def compute_len(ex):
    """Return {"length": n}, the token count of the prompt/completion pair.

    The pair is tokenized without padding or truncation, so n is the full
    length of the encoded example.
    """
    encoded = tokenizer(
        ex["prompt"],
        ex["completion"],
        padding=False,
        truncation=False,
    )
    n_tokens = len(encoded["input_ids"])
    return {"length": n_tokens}

# Measure every example one at a time and take the longest as the padding length.
ds_lens = ds.map(compute_len, batched=False)
max_len = max(ds_lens["length"])
print(f"Longest example is {max_len} tokens")


def tokenize_fn(example):
    """Tokenize one prompt/completion pair and build its training labels.

    The pair is padded to `max_len` tokens (no truncation). Labels are a copy
    of the input ids in which the first `n_prompt` positions are set to -100,
    where `n_prompt` is the token count of the prompt tokenized on its own.

    Returns:
        Dict with "input_ids", "attention_mask" and "labels".
    """
    prompt_text = example["prompt"]
    encoded = tokenizer(
        prompt_text,
        example["completion"],
        padding="max_length",
        max_length=max_len,
        truncation=False,
    )
    token_ids = encoded["input_ids"]

    # Number of leading positions to exclude from the loss.
    n_prompt = len(tokenizer(prompt_text)["input_ids"])

    # NOTE(review): only the first n_prompt positions are masked, so padding
    # positions keep their token ids as labels; this also assumes the prompt
    # occupies the first positions (right padding) — confirm both are intended.
    labels = [-100] * n_prompt + token_ids[n_prompt:]

    return {"input_ids": token_ids, "attention_mask": encoded["attention_mask"], "labels": labels}

# Tokenize every example individually (tokenize_fn expects a single example).
tokenized_ds = ds.map(tokenize_fn, batched=False)

In [None]:
from transformers import DataCollatorWithPadding

# Collator configured to pad batches to the same fixed length used in tokenize_fn.
data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=max_len)

In [None]:
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer

# Load the base causal LM and tell it which id the tokenizer uses for padding.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    pad_token_id=tokenizer.pad_token_id,
)

# The tokenizer cell may have added a "[PAD]" token. If the vocabulary is now
# larger than the embedding table, grow the table so every token id has a row.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

# Fine-tuning hyper-parameters.
training_args = TrainingArguments(
    output_dir="./qwen-finetuned-rpn",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,     # 8 * 4 = 32 examples per optimizer step per device
    num_train_epochs=3,
    learning_rate=2e-5,
    fp16=True,
    logging_steps=100,
    save_total_limit=2,
    save_steps=500,
    evaluation_strategy="no",          # NOTE(review): confirm this argument name against the installed transformers version
    report_to="none",
)

# Trainer over the tokenized dataset; no evaluation dataset is passed.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds,
    data_collator=data_collator,
)


Train and save model

In [None]:
# Run fine-tuning, then write the model and the tokenizer to the same directory
# so the checkpoint can be reloaded from that single path.
trainer.train()
trainer.save_model("./qwen-finetuned-rpn")
tokenizer.save_pretrained("./qwen-finetuned-rpn")

In [None]:
# Reload the fine-tuned checkpoint from the directory written by the training cell.
qwenllm_rpn_ft = clm.TransformersLM("./qwen-finetuned-rpn")

In [None]:
# RPNLogitsProcessor requires the tokenizer used to decode candidate sequences;
# use the one that belongs to the fine-tuned model.
rpn_c = RPNLogitsProcessor(qwenllm_rpn_ft.tokenizer)

In [None]:
# Sampling configuration for the fine-tuned model (same prompts as the plain-RPN run).
num_particles = 3
max_length = 10
B = len(prompts_rpn)

smc_sampler_rpn_ft = clm.SMCSampler(qwenllm_rpn_ft, rpn_c)
smc_generated_token_ids_rpn_ft = smc_sampler_rpn_ft.sample(
    batch_rpn.input_ids,
    max_length=max_length,
    num_particles=num_particles,
    ess_threshold=10,
)

# Repeat each prompt once per particle, and reshape the first element returned
# by sample() to one row per (prompt, particle) pair before decoding.
prompt_texts = qwenllm_rpn_ft.tokenizer.batch_decode(
    batch_rpn.input_ids.repeat_interleave(num_particles, dim=0)
)
particle_texts = qwenllm_rpn_ft.tokenizer.batch_decode(
    smc_generated_token_ids_rpn_ft[0].reshape(B * num_particles, max_length)
)
for prompt_text, particle_text in zip(prompt_texts, particle_texts):
    print("--------------- \n", prompt_text + particle_text)

## Random model

Compare SFT / LCD / SMC / Random model

With SMC: show how increasing `num_particles` increases the log-probability of the most probable sentence, and check whether it improves the results of the RPN translation.

**Important:** do I need to implement a way to estimate the normalizing constant when doing rejection sampling?