In [1]:
from typing import List, Dict, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Device

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Model and Tokenizer

In [21]:
pretrained_model_name_or_path = "./models/google--gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
model.to(device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  9.79it/s]


Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2SdpaAttention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), 

## Set Banned Words

In [38]:
banned_words = ["talk", "listen", "fuck you"]

In [74]:
import copy
import re

new_token_candidates = []
new_token_id_candidates = []

# Define a pattern to filter out unwanted excessive whitespace tokens (e.g., multiple tabs or newlines)
invalid_token_pattern = re.compile(r'^[\t\n\r\f\v\s]*$')  # Matches tokens that are made of 3 or more whitespace characters

# Initial matching of tokens that start with banned word prefixes
for token, token_id in tokenizer.vocab.items():
    token_stripped = token.strip()

    # Filter out tokens that consist of excessive whitespace or control characters
    if invalid_token_pattern.match(token_stripped):
        continue

    for banned_word in banned_words:
        if banned_word.startswith(token_stripped):
            new_token_candidates.append([token])  # Store as a list for consistency
            new_token_id_candidates.append([token_id])

print("Filtered First Candidates:", new_token_candidates)
print("Filtered First IDs:", new_token_id_candidates)

# Initialize final routes for fully matched banned words
final_routes = []
final_id_routes = []
for new_token, new_token_id in zip(new_token_candidates, new_token_id_candidates):
    if "".join(new_token).strip() in banned_words:
        final_routes.append(new_token)
        final_id_routes.append(new_token_id)

# Iteratively expand candidates using BFS-like approach
while new_token_candidates:
    curr_token_candidates = new_token_candidates
    curr_token_id_candidates = new_token_id_candidates
    new_token_candidates = []
    new_token_id_candidates = []

    # Iterate over vocabulary and expand each candidate
    for token, token_id in tokenizer.vocab.items():
        token_stripped = token.strip()

        # Skip tokens that consist of excessive whitespace or control characters
        if invalid_token_pattern.match(token_stripped):
            continue

        for candidate_token, candidate_token_id in zip(curr_token_candidates, curr_token_id_candidates):
            curr_token = candidate_token + [token]
            curr_token_ids = candidate_token_id + [token_id]

            curr_token_str = "".join(curr_token).strip()

            # Check if the current token combination matches or is a prefix of any banned word
            for banned_word in banned_words:
                if curr_token_str == banned_word:
                    # Full match found, add to final routes
                    final_routes.append(curr_token)
                    final_id_routes.append(curr_token_ids)
                elif banned_word.startswith(curr_token_str):
                    # Partial match, keep expanding this candidate
                    new_token_candidates.append(curr_token)
                    new_token_id_candidates.append(curr_token_ids)

print("Final Routes:", final_routes)
print("Final ID Routes:", final_id_routes)

# Optimizations:
# 1. Added regex filtering (`invalid_token_pattern`) to filter out tokens made of excessive whitespace or control characters.
# 2. Removed unnecessary deep copy (`copy.deepcopy`), instead directly reassign lists which is more efficient.
# 3. Combined stripping and joining operations to reduce redundant code.
# 4. Improved readability by adding meaningful comments for each logical section.

Filtered First Candidates: [['li'], ['fuck'], ['tal'], ['l'], ['lis'], ['ta'], ['liste'], ['t'], ['talk'], ['listen'], ['f'], ['list'], ['fu']]
Filtered First IDs: [[515], [34024], [3559], [235257], [15063], [516], [44003], [235251], [33085], [18998], [235266], [1701], [12819]]
Final Routes: [['talk'], ['listen'], ['li', 'sten'], ['l', 'isten'], ['tal', 'k'], ['liste', 'n'], ['list', 'en'], ['lis', 'ten'], ['ta', 'lk'], ['t', 'alk'], ['l', 'i', 'sten'], ['ta', 'l', 'k'], ['t', 'al', 'k'], ['l', 'iste', 'n'], ['lis', 'te', 'n'], ['li', 'ste', 'n'], ['list', 'e', 'n'], ['li', 'st', 'en'], ['l', 'ist', 'en'], ['lis', 't', 'en'], ['l', 'is', 'ten'], ['li', 's', 'ten'], ['t', 'a', 'lk'], ['t', 'a', 'l', 'k'], ['l', 'is', 'te', 'n'], ['li', 's', 'te', 'n'], ['l', 'i', 'ste', 'n'], ['li', 'st', 'e', 'n'], ['l', 'ist', 'e', 'n'], ['lis', 't', 'e', 'n'], ['l', 'i', 'st', 'en'], ['l', 'is', 't', 'en'], ['li', 's', 't', 'en'], ['l', 'i', 's', 'ten'], ['l', 'i', 's', 'te', 'n'], ['l', 'i', 'st', '

In [71]:
for items, ids in zip(final_routes, final_id_routes):
    print(items, ids)

['talk'] [33085]
['listen'] [18998]
['li', 'sten'] [515, 5547]
['l', 'isten'] [235257, 17071]
['tal', 'k'] [3559, 235273]
['liste', 'n'] [44003, 235254]
['list', 'en'] [1701, 479]
['lis', 'ten'] [15063, 965]
['ta', 'lk'] [516, 26159]
['t', 'alk'] [235251, 2071]
['l', 'i', 'sten'] [235257, 235252, 5547]
['ta', 'l', 'k'] [516, 235257, 235273]
['t', 'al', 'k'] [235251, 492, 235273]
['l', 'iste', 'n'] [235257, 3671, 235254]
['lis', 'te', 'n'] [15063, 488, 235254]
['li', 'ste', 'n'] [515, 2855, 235254]
['list', 'e', 'n'] [1701, 235249, 235254]
['li', 'st', 'en'] [515, 490, 479]
['l', 'ist', 'en'] [235257, 694, 479]
['lis', 't', 'en'] [15063, 235251, 479]
['l', 'is', 'ten'] [235257, 502, 965]
['li', 's', 'ten'] [515, 235256, 965]
['t', 'a', 'lk'] [235251, 235250, 26159]
['t', 'a', 'l', 'k'] [235251, 235250, 235257, 235273]
['l', 'is', 'te', 'n'] [235257, 502, 488, 235254]
['li', 's', 'te', 'n'] [515, 235256, 488, 235254]
['l', 'i', 'ste', 'n'] [235257, 235252, 2855, 235254]
['li', 'st', 'e',

## Normal Decoding

In [41]:
input_text = "Can we talk?"
inputs = tokenizer(input_text, return_tensors="pt").to(device)

In [42]:
def custom_generate(input_ids: torch.Tensor, max_length: int = 50) -> str:
    for _ in range(max_length):
        # Generate new tokens
        outputs = model(input_ids.to(device), return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]
        
        new_generated_token_id = torch.argmax(logits, dim=-1)

        if new_generated_token_id == tokenizer.eos_token_id:
            break

        input_ids = torch.cat((input_ids, new_generated_token_id.unsqueeze(0)), dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)

In [43]:
print(custom_generate(input_ids=inputs.input_ids))

<bos>Can we talk?

I'm here to listen and help in any way I can. 

What's on your mind? 
<end_of_turn>


## FSM

In [56]:
class FSMProcessor:
    def __init__(self, special_token_ids_list: List[List[int]], end_state: int = -1) -> None:
        self.end_state = end_state
        self.next_state = 1
        self.curr_state = 0
        self.fsm = {}
        self.special_words = []

        # Track partial matches
        self.partial_match_state = None
        self.partial_tokens = []

        self.update_group(special_token_ids_list)

    def update(self, special_token_ids: List[int]) -> None:
        curr_state = 0

        for idx, special_token_id in enumerate(special_token_ids):
            if curr_state not in self.fsm:
                self.fsm[curr_state] = []

            state2id = [items[0] for items in self.fsm[curr_state]]
            if special_token_id not in state2id:
                if idx == len(special_token_ids) - 1:
                    self.fsm[curr_state].append([special_token_id, self.end_state])
                else:
                    self.fsm[curr_state].append([special_token_id, self.next_state])
                    curr_state = self.next_state
                    self.next_state += 1
            else:
                for fsm_idx in range(len(self.fsm[curr_state])):
                    if special_token_id == self.fsm[curr_state][fsm_idx][0] and idx == len(special_token_ids) - 1:
                        self.fsm[curr_state][fsm_idx][1] = self.end_state
                        break
                    elif special_token_id == self.fsm[curr_state][fsm_idx][0]:
                        curr_state = self.fsm[curr_state][fsm_idx][1]
                        break

    def update_group(self, special_token_ids_list: List[List[int]]) -> None:
        for special_token_ids in special_token_ids_list:
            self.update(special_token_ids=special_token_ids)

    def get_fsm_data(self) -> Dict[str, List[Tuple[int, int]]]:
        return self.fsm
    
    def detect(self, token: int) -> bool:
        """
        Detect if the current token leads to a sensitive sequence.
        Updates the current state and returns True if it reaches the end state.
        """
        if self.curr_state in self.fsm:
            for transition in self.fsm[self.curr_state]:
                if transition[0] == token:
                    self.curr_state = transition[1]

                    # If the current state reaches the end state
                    return self.curr_state == self.end_state
        
        # If the token does not match, reset the current state
        self.curr_state = 0
        return False

In [57]:
fsm_processor = FSMProcessor(special_token_ids_list=final_id_routes)
fsm_processor.get_fsm_data()

{0: [[33085, -1],
  [18998, -1],
  [515, 1],
  [235257, 2],
  [3559, 3],
  [44003, 4],
  [1701, 5],
  [15063, 6],
  [516, 7],
  [235251, 8]],
 1: [[5547, -1], [2855, 14], [490, 16], [235256, 20]],
 2: [[17071, -1], [235252, 9], [3671, 12], [694, 17], [502, 19]],
 3: [[235273, -1]],
 4: [[235254, -1]],
 5: [[479, -1], [235249, 15]],
 6: [[965, -1], [488, 13], [235251, 18]],
 7: [[26159, -1], [235257, 10]],
 8: [[2071, -1], [492, 11], [235250, 21]],
 9: [[5547, -1], [2855, 25], [490, 29], [235256, 32]],
 10: [[235273, -1]],
 11: [[235273, -1]],
 12: [[235254, -1]],
 13: [[235254, -1]],
 14: [[235254, -1]],
 15: [[235254, -1]],
 16: [[479, -1], [235249, 26]],
 17: [[479, -1], [235249, 27]],
 18: [[479, -1], [235249, 28]],
 19: [[965, -1], [488, 23], [235251, 30]],
 20: [[965, -1], [488, 24], [235251, 31]],
 21: [[26159, -1], [235257, 22]],
 22: [[235273, -1]],
 23: [[235254, -1]],
 24: [[235254, -1]],
 25: [[235254, -1]],
 26: [[235254, -1]],
 27: [[235254, -1]],
 28: [[235254, -1]],
 29:

In [63]:
def custom_generate_with_fsm_filter(
    input_ids: torch.Tensor,
    fsm_processor: FSMProcessor,
    max_length: int = 20,
) -> str:
    # Historical mask list used to record masked tokens at each decoding step
    masked_tokens_history = {}
    past_key_values = None
    steps = 0

    while steps < max_length:
        steps += 1

        # Generate new token with kv cache
        outputs = model(input_ids, past_key_values=past_key_values, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Update kv cache
        past_key_values = outputs.past_key_values

        # Check if there are already masked tokens at the current step
        if steps in masked_tokens_history:
            for masked_token_id in masked_tokens_history[steps]:
                logits[:, masked_token_id] = -float("inf")
        else:
            masked_tokens_history[steps] = set()

        # Decode the generated token
        generated_token_id = torch.argmax(logits, dim=-1).item()
        combined_ids = torch.cat((input_ids, torch.tensor([[generated_token_id]], device=input_ids.device)), dim=-1)

        # Check FSM for sensitive sequences
        if fsm_processor.detect(generated_token_id):
            # Detected a sensitive sequence, initiate rollback
            rollback_length = fsm_processor.partial_match_state + 1 if fsm_processor.partial_match_state is not None else 1
            steps = steps - rollback_length + 1
            rollbacks_ids = combined_ids[:, :-rollback_length]
            input_ids = rollbacks_ids
            print(f"Rollback detected. Rolling back from step {steps + rollback_length} to step {steps}")

            # Reset FSM state
            fsm_processor.curr_state = 0
            fsm_processor.partial_match_state = None

            # Reset past_key_values when rolling back
            past_key_values = None

            # Recalculate logits based on rolled-back sequence
            outputs = model(input_ids, return_dict=True, use_cache=True)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            # Mask the first token of the sensitive sequence
            first_token_id = generated_token_id
            print(f"Masking token id: {first_token_id}, Masking token: {tokenizer.decode(first_token_id)}")

            # Update the historical mask list to record the token at this step
            masked_tokens_history[steps].add(first_token_id)

            for masked_token_id in masked_tokens_history[steps]:
                logits[:, masked_token_id] = -float("inf")

            # Generate the token again after masking
            generated_token_id = torch.argmax(logits, dim=-1).item()

        # Update input_ids with the generated token
        input_ids = torch.cat((input_ids, torch.tensor([[generated_token_id]], device=input_ids.device)), dim=1)

        print(f"Step {steps}: ID: {generated_token_id} Generated token: {tokenizer.decode(generated_token_id)}")

        if generated_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)

In [64]:
print(custom_generate_with_fsm_filter(
    input_ids=inputs.input_ids,
    fsm_processor=fsm_processor,
    max_length=50,
))

Step 1: ID: 109 Generated token: 


Step 2: ID: 235285 Generated token: I
Step 3: ID: 235303 Generated token: '
Step 4: ID: 235262 Generated token: m
Step 5: ID: 1517 Generated token:  here
Step 6: ID: 577 Generated token:  to
Step 7: ID: 10724 Generated token:  listen
Step 8: ID: 578 Generated token:  and
Step 9: ID: 1707 Generated token:  help
Step 10: ID: 575 Generated token:  in
Step 11: ID: 1089 Generated token:  any
Step 12: ID: 1703 Generated token:  way
Step 13: ID: 590 Generated token:  I
Step 14: ID: 798 Generated token:  can
Step 15: ID: 235265 Generated token: .
Step 16: ID: 235248 Generated token:  
Step 17: ID: 109 Generated token: 


Step 18: ID: 1841 Generated token: What
Step 19: ID: 235303 Generated token: '
Step 20: ID: 235256 Generated token: s
Step 21: ID: 611 Generated token:  on
Step 22: ID: 861 Generated token:  your
Step 23: ID: 3403 Generated token:  mind
Step 24: ID: 235336 Generated token: ?
Step 25: ID: 235248 Generated token:  
Step 26: ID: 108 Generated t