In [1]:
from typing import List, Dict, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


## Device

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Load Model and Tokenizer

In [3]:
pretrained_model_name_or_path = "./models/google--gemma-2-2b-it"
# pretrained_model_name_or_path = "openai-community/gpt2"


tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)

# model = model.to(device)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.96it/s]


## Normal Decoding

In [6]:
input_text = "Can we talk?"
inputs = tokenizer(input_text, return_tensors="pt").to(device)

In [7]:
def custom_generate(input_ids: torch.Tensor, max_length: int = 50) -> str:
    for _ in range(max_length):
        # Generate new tokens
        outputs = model(input_ids, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]
        
        new_generated_token_id = torch.argmax(logits, dim=-1)

        if new_generated_token_id == tokenizer.eos_token_id:
            break

        input_ids = torch.cat((input_ids, new_generated_token_id.unsqueeze(0)), dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)

In [8]:
print(custom_generate(input_ids=inputs.input_ids))

KeyboardInterrupt: 

## Special Words Filter Decoding V1

In [39]:
def custom_generate_with_special_words_filter(
    input_ids: torch.Tensor,
    special_words: List[str],
    max_length: int = 20,
) -> str:
    for _ in range(max_length):
        # Generate new tokens
        outputs = model(input_ids, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Generate token
        generated_token_id = torch.argmax(logits, dim=-1)
        combined_ids = torch.cat((input_ids, generated_token_id.unsqueeze(0)), dim=-1)
        combined_text = tokenizer.decode(combined_ids[0], skip_special_tokens=True)

        if generated_token_id == tokenizer.eos_token:
            break

        # Check for each special word
        for special_word in special_words:
            if special_word in combined_text:
                # Tokenize the special word
                special_word_tokenized = tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids
                special_word_length = special_word_tokenized.shape[1]
                print(f"Detect special word: {special_word}\n Start roll-back process...")

                # Rollback to before the special word
                rollbacks_ids = combined_ids[:, :-special_word_length]
                input_ids = rollbacks_ids
                print(f"{combined_ids.shape[1]} -> {rollbacks_ids.shape[1]}")

                # Recompute logits after rollback
                outputs = model(input_ids, return_dict=True, use_cache=True)
                logits = outputs.logits[:, -1, :]  # Recompute logits based on rolled-back input
                print(logits.shape)

                # Mask only the first token of the special word
                first_token_id = special_word_tokenized[0, 0]
                print(f"Masking token: {tokenizer.decode(first_token_id)}")
                logits[:, first_token_id] = -float("inf")  # Apply mask to the logits

        # Generate new token
        new_generated_token = torch.argmax(logits, dim=-1)
        print(f"Generated token: {tokenizer.decode(new_generated_token)}")
        input_ids = torch.cat((input_ids, new_generated_token.unsqueeze(0)), dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


In [40]:
print(custom_generate_with_special_words_filter(
    input_ids=inputs.input_ids,
    special_words=["not sure", "fan"],
))

Generated token: 


Generated token: I
Generated token: '
Generated token: m
Generated token:  here
Generated token:  to
Detect special word: listen
 Start roll-back process...
12 -> 11
torch.Size([1, 256000])
Masking token: listen
Generated token:  listen
Detect special word: listen
 Start roll-back process...
13 -> 12
torch.Size([1, 256000])
Masking token: listen
Generated token:  and
Detect special word: listen
 Start roll-back process...
14 -> 13
torch.Size([1, 256000])
Masking token: listen
Generated token:  help
Detect special word: listen
 Start roll-back process...
15 -> 14
torch.Size([1, 256000])
Masking token: listen
Generated token:  in
Detect special word: listen
 Start roll-back process...
16 -> 15
torch.Size([1, 256000])
Masking token: listen
Generated token:  any
Detect special word: listen
 Start roll-back process...
17 -> 16
torch.Size([1, 256000])
Masking token: listen
Generated token:  way
Detect special word: listen
 Start roll-back process...
18 -> 17
torch.Size([1

In [41]:
print(custom_generate_with_special_words_filter(
    input_ids=inputs.input_ids,
    special_words=["listen"],
))

Generated token: 


Generated token: I
Generated token: '
Generated token: m
Generated token:  here
Generated token:  to
Detect special word: listen
 Start roll-back process...
12 -> 11
torch.Size([1, 256000])
Masking token: listen
Generated token:  listen
Detect special word: listen
 Start roll-back process...
13 -> 12
torch.Size([1, 256000])
Masking token: listen
Generated token:  and
Detect special word: listen
 Start roll-back process...
14 -> 13
torch.Size([1, 256000])
Masking token: listen
Generated token:  help
Detect special word: listen
 Start roll-back process...
15 -> 14
torch.Size([1, 256000])
Masking token: listen
Generated token:  in
Detect special word: listen
 Start roll-back process...
16 -> 15
torch.Size([1, 256000])
Masking token: listen
Generated token:  any
Detect special word: listen
 Start roll-back process...
17 -> 16
torch.Size([1, 256000])
Masking token: listen
Generated token:  way
Detect special word: listen
 Start roll-back process...
18 -> 17
torch.Size([1

In [42]:
print(custom_generate_with_special_words_filter(
    input_ids=inputs.input_ids,
    special_words=["listen", " listen"],
))

Generated token: 


Generated token: I
Generated token: '
Generated token: m
Generated token:  here
Generated token:  to
Detect special word: listen
 Start roll-back process...
12 -> 11
torch.Size([1, 256000])
Masking token: listen
Detect special word:  listen
 Start roll-back process...
12 -> 11
torch.Size([1, 256000])
Masking token:  listen
Generated token:  help
Generated token:  you
Generated token:  with
Generated token:  whatever
Generated token:  you
Generated token:  need
Generated token: .
Generated token:  
Generated token: 


Generated token: Please
Generated token:  tell
Generated token:  me
Generated token:  what
Generated token: '
Can we talk?

I'm here to help you with whatever you need. 

Please tell me what'


In [128]:
tokenizer("listen", return_tensors="pt", add_special_tokens=False).input_ids

tensor([[18998]])

In [11]:
def custom_generate_with_special_words_filter_v2(
    input_ids: torch.Tensor,
    special_words: List[str],
    max_length: int = 20,
) -> str:
    # Historical mask list used to record masked tokens at each decoding step
    masked_tokens_history = {}

    step = 0

    while step < max_length:
        step += 1

        # Generate new tokens
        outputs = model(input_ids, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Check if there are already masked tokens at the current step
        if step in masked_tokens_history:
            for token_id in masked_tokens_history[step]:
                logits[:, token_id] = -float("inf")  # Mask previously invalid tokens

        # Generate the token
        generated_token_id = torch.argmax(logits, dim=-1)
        combined_ids = torch.cat((input_ids, generated_token_id.unsqueeze(0)), dim=-1)
        combined_text = tokenizer.decode(combined_ids[0], skip_special_tokens=True)

        if generated_token_id == tokenizer.eos_token_id:
            break

        # Check for each sensitive word
        need_to_generate_again = False

        for special_word in special_words:
            if special_word in combined_text:
                # Convert the sensitive word into tokens
                special_word_tokenized = tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids
                special_word_length = special_word_tokenized.shape[1]
                print(f"Detect special word: {special_word}\n Start roll-back process...")

                # Roll back to before the sensitive word
                rollbacks_ids = combined_ids[:, :-special_word_length]
                input_ids = rollbacks_ids
                print(f"Roll back from {combined_ids.shape[1]} to {rollbacks_ids.shape[1]}")

                # Recalculate logits based on the rolled-back sequence
                outputs = model(input_ids, return_dict=True, use_cache=True)
                logits = outputs.logits[:, -1, :]  # Recalculate logits based on rolled-back input

                # Only mask the first token of the sensitive word
                first_token_id = special_word_tokenized[0, 0]
                print(f"Masking token: {tokenizer.decode(first_token_id)}")
                logits[:, first_token_id] = -float("inf")  # Mask the first token of the sensitive word

                # Update the historical mask list to record the token at this step
                if step not in masked_tokens_history:
                    masked_tokens_history[step] = set()
                masked_tokens_history[step].add(first_token_id)

                need_to_generate_again = True

        # Generate the next token
        if need_to_generate_again:
            generated_token_id = torch.argmax(logits, dim=-1)
            print(f"Generated token: {tokenizer.decode(generated_token_id)}")

        input_ids = torch.cat((input_ids, generated_token_id.unsqueeze(0)), dim=1)

    print()
    from pprint import pprint
    pprint(masked_tokens_history)
    print()

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)


In [14]:
print(custom_generate_with_special_words_filter_v2(
    input_ids=inputs.input_ids,
    special_words=["not sure", "  not"],
    max_length=100,
))

Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not sure
 Start roll-back process...
Roll back from 10 to 8
Masking token: not
Generated token:  not
Detect special word: not

In [4]:
def custom_generate_with_special_words_filter_v3(
    input_ids: torch.Tensor,
    special_words: List[str],
    max_length: int = 20,
) -> str:
    # Historical mask list used to record masked tokens at each decoding step
    masked_tokens_history = {}
    past_key_values = None
    steps = 0

    while steps < max_length:
        steps += 1

        # Generate new token with kv cache
        outputs = model(input_ids, past_key_values=past_key_values, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Update kv cache
        past_key_values = outputs.past_key_values

        # Check if there are already masked tokens at the current steps
        if steps in masked_tokens_history:
            for masked_token_id in masked_tokens_history[steps]:
                logits[:, masked_token_id] = -float("inf")
        else:
            masked_tokens_history[steps] = set()

        # Decode the generated token
        generated_token_id = torch.argmax(logits, dim=-1)
        combined_ids = torch.cat((input_ids, generated_token_id.unsqueeze(0)), dim=-1)
        combined_text = tokenizer.decode(combined_ids[0], skip_special_tokens=True)

        print(steps, combined_text)

        for special_word in special_words:
            special_word_length = len(special_word)
            if special_word in combined_text[-special_word_length:]:
                special_word_tokenized = tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids
                special_word_length = special_word_tokenized.shape[1]
                print(f"Detect special word: {special_word}\n Start roll-back process...")

                # Roll back to before sensitive word
                steps = steps - special_word_length + 1
                rollbacks_ids = combined_ids[:, :-special_word_length]
                input_ids = rollbacks_ids
                print(f"Roll back from {combined_ids.shape[1]} to {rollbacks_ids.shape[1]}")

                # Recalculate logits based on the rolled-back sequence
                # Reset past_key_values when rolling back
                past_key_values = None

                print(input_ids.shape)
                outputs = model(input_ids, return_dict=True, use_cache=True)
                logits = outputs.logits[:, -1, :]  # Recalculate logits based on rolled-back input
                past_key_values = outputs.past_key_values

                # Only mask the first token of the sensitive word
                first_token_id = int(special_word_tokenized[0, 0])
                print(f"Masking token id: {first_token_id}, Masking token: {tokenizer.decode(first_token_id)}")

                # Update the historical mask list to record the token at this step
                masked_tokens_history[steps].add(first_token_id)

                for masked_token_id in masked_tokens_history[steps]:
                    logits[:, masked_token_id] = -float("inf")

                # Generate the token again after masking
                generated_token_id = torch.argmax(logits, dim=-1)

        # Update input_ids with the generated token
        input_ids = torch.cat((input_ids, generated_token_id.unsqueeze(0)), dim=1)

        print()
        from pprint import pprint
        pprint(masked_tokens_history)
        print()

        if generated_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)

In [5]:
print(custom_generate_with_special_words_filter_v3(
    input_ids=inputs.input_ids,
    special_words=["talk", "talk?", "not sure"],
    max_length=20,
))

NameError: name 'inputs' is not defined

In [None]:
def custom_generate_with_special_words_filter_fsm(
    input_ids: torch.Tensor,
    special_words: List[str],
    max_length: int = 20,
) -> str:
    # Historical mask list used to record masked tokens at each decoding step
    masked_tokens_states = {}
    past_key_values = None
    steps = 0
    decode_state = 0

    special_tokens_list = [
        tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids[0].tolist()
        for special_word in special_words
    ]

    # TODO

    while steps < max_length:
        steps += 1

        # Generate new token with kv cache
        outputs = model(input_ids, past_key_values=past_key_values, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Update kv cache
        past_key_values = outputs.past_key_values

        # Check if there are already masked tokens at the current steps
        if steps not in masked_tokens_history:
            for masked_token_id in masked_tokens_history[steps]:
                logits[:, masked_token_id] = -float("inf")
        else:
            masked_tokens_history[steps] = set()

        # Decode the generated token
        generated_token_id = torch.argmax(logits, dim=-1)
        combined_ids = torch.cat((input_ids, generated_token_id.unsqueeze(0)), dim=-1)
        combined_text = tokenizer.decode(combined_ids[0], skip_special_tokens=True)

        print(steps, combined_text)

        for special_word in special_words:
            special_word_length = len(special_word)
            if special_word in combined_text[-special_word_length:]:
                special_word_tokenized = tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids
                special_word_length = special_word_tokenized.shape[1]
                print(f"Detect special word: {special_word}\n Start roll-back process...")

                # Roll back to before sensitive word
                steps = steps - special_word_length + 1
                rollbacks_ids = combined_ids[:, :-special_word_length]
                input_ids = rollbacks_ids
                print(f"Roll back from {combined_ids.shape[1]} to {rollbacks_ids.shape[1]}")

                # Recalculate logits based on the rolled-back sequence
                # Reset past_key_values when rolling back
                past_key_values = None

                print(input_ids.shape)
                outputs = model(input_ids, return_dict=True, use_cache=True)
                logits = outputs.logits[:, -1, :]  # Recalculate logits based on rolled-back input
                past_key_values = outputs.past_key_values

                # Only mask the first token of the sensitive word
                first_token_id = int(special_word_tokenized[0, 0])
                print(f"Masking token id: {first_token_id}, Masking token: {tokenizer.decode(first_token_id)}")

                # Update the historical mask list to record the token at this step
                masked_tokens_history[steps].add(first_token_id)

                for masked_token_id in masked_tokens_history[steps]:
                    logits[:, masked_token_id] = -float("inf")

                # Generate the token again after masking
                generated_token_id = torch.argmax(logits, dim=-1)

        # Update input_ids with the generated token
        input_ids = torch.cat((input_ids, generated_token_id.unsqueeze(0)), dim=1)

        print()
        from pprint import pprint
        pprint(masked_tokens_history)
        print()

        if generated_token_id == tokenizer.eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)

In [19]:
special_words = ["fuck you dad", "fuck your mother", "fuck you"]

special_tokens_list = [
    tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids[0].tolist()
    for special_word in special_words
]

special_tokens_list = [items.split() for items in special_words]

end_state = -1
next_state = 1
fsm = {}

for special_tokens in special_tokens_list:
    print(special_tokens)
    curr_state = 0

    for idx, special_token in enumerate(special_tokens):
        if curr_state not in fsm:
            fsm[curr_state] = []

        state2tokens = [items[0] for items in fsm[curr_state]]
        if special_token not in state2tokens:
            # END
            if idx == len(special_tokens) - 1:
                fsm[curr_state].append([special_token, end_state])
            else:
                fsm[curr_state].append([special_token, next_state])
                curr_state = next_state
                next_state += 1
        else:
            for fsm_idx in range(len(fsm[curr_state])):
                if special_token == fsm[curr_state][fsm_idx][0] and idx == len(special_tokens) - 1:
                    fsm[curr_state][fsm_idx][1] = end_state
                    break
                elif special_token == fsm[curr_state][fsm_idx][0]:
                    curr_state = fsm[curr_state][fsm_idx][1]
                    break

from pprint import pprint
pprint(fsm)

['fuck', 'you', 'dad']
['fuck', 'your', 'mother']
['fuck', 'you']
{0: [['fuck', 1]],
 1: [['you', -1], ['your', 3]],
 2: [['dad', -1]],
 3: [['mother', -1]]}


In [38]:
class FSMProcessor:
    def __init__(self, sepcial_tokens_list: List[List[str]], end_state: int = -1) -> None:
        self.end_state = end_state
        self.next_state = 1
        self.curr_state = 0
        self.fsm = {}
        self.special_words = []

        # Track partial matches
        self.partial_match_state = None
        self.partial_tokens = []

        self.update_group(sepcial_tokens_list)

    def update(self, special_tokens: List[int]) -> None:
        curr_state = 0

        for idx, special_token in enumerate(special_tokens):
            if curr_state not in self.fsm:
                self.fsm[curr_state] = []

            state2tokens = [items[0] for items in self.fsm[curr_state]]
            if special_token not in state2tokens:
                if idx == len(special_tokens) - 1:
                    self.fsm[curr_state].append([special_token, self.end_state])
                else:
                    self.fsm[curr_state].append([special_token, self.next_state])
                    curr_state = self.next_state
                    self.next_state += 1
            else:
                for fsm_idx in range(len(self.fsm[curr_state])):
                    if special_token == self.fsm[curr_state][fsm_idx][0] and idx == len(special_tokens) - 1:
                        self.fsm[curr_state][fsm_idx][1] = self.end_state
                        break
                    elif special_token == self.fsm[curr_state][fsm_idx][0]:
                        curr_state = self.fsm[curr_state][fsm_idx][1]
                        break

    def update_group(self, special_tokens_list: List[List[int]]) -> None:
        for special_tokens in special_tokens_list:
            self.update(special_tokens=special_tokens)

    def get_fsm_data(self) -> Dict[str, List[Tuple[int, int]]]:
        return self.fsm
    
    def detect(self, token: int) -> bool:
        """
        Detect if the current token leads to a sensitive sequence.
        Updates the current state and returns True if it reaches the end state.
        """
        if self.curr_state in self.fsm:
            for transition in self.fsm[self.curr_state]:
                if transition[0] == token:
                    self.curr_state = transition[1]

                    # If the current state reaches the end state
                    return self.curr_state == self.end_state
        
        # If the token does not match, reset the current state
        self.curr_state = 0
        return False


In [39]:
fsm_processor = FSMProcessor(special_tokens_list)

In [40]:
special_tokens_list

[['fuck', 'you', 'dad'], ['fuck', 'your', 'mother'], ['fuck', 'you']]

In [41]:
fsm_processor.update(['Thank', 'you', 'dad'])

In [42]:
fsm_processor.get_fsm_data()

{0: [['fuck', 1], ['Thank', 4]],
 1: [['you', -1], ['your', 3]],
 2: [['dad', -1]],
 3: [['mother', -1]],
 4: [['you', 5]],
 5: [['dad', -1]]}

In [43]:
fsm_processor.detect("Thank")

False

In [44]:
fsm_processor.detect("you")

False

In [46]:
fsm_processor.detect("dad")

True

## Improvement FSMProcessor

In [None]:
class FSMProcessor:
    def __init__(self, special_words: List[str], end_state: int = -1) -> None:
        self.end_state = end_state
        self.next_state = 1
        self.curr_state = 0
        self.fsm = {}
        self.special_words = special_words

        # Track partial matches
        self.partial_match_state = None
        self.partial_tokens = []

        self.update_group(special_words=special_words)

    def update(self, special_word: str) -> None:
        curr_state = 0

        for idx, special_char in enumerate(special_word):
            if curr_state not in self.fsm:
                self.fsm[curr_state] = []

            state2tokens = [items[0] for items in self.fsm[curr_state]]

            if special_char not in state2tokens:
                if idx == len(special_word) - 1:
                    self.fsm[curr_state].append([special_char, self.end_state])
                else:
                    self.fsm[curr_state].append([special_char, self.next_state])
                    curr_state = self.next_state
                    self.next_state += 1
            else:
                for fsm_idx in range(len(self.fsm[curr_state])):
                    if special_char == self.fsm[curr_state][fsm_idx][0] and idx == len(special_word) - 1:
                        self.fsm[curr_state][fsm_idx][1] = self.end_state
                        break
                    elif special_char == self.fsm[curr_state][fsm_idx][0]:
                        curr_state = self.fsm[curr_state][fsm_idx][1]
                        break

    def update_group(self, special_words: List[str]) -> None:
        for special_word in special_words:
            self.update(special_word=special_word)

    def get_fsm_data(self) -> Dict[str, List[Tuple[int, int]]]:
        return self.fsm
    
    def detect(self, token: str) -> bool:
        """
        Detect if the current token leads to a sensitive sequence.
        Updates the current state and returns True if it reaches the end state.
        """
        for char in token:
            for transition in self.fsm[self.curr_state]:
                if transition[0] == char:
                    self.curr_state = transition[1]

                    # If the current state reaches the end state
                    return self.curr_state == self.end_state


        if self.curr_state in self.fsm:
            for transition in self.fsm[self.curr_state]:
                if transition[0] == token:
                    self.curr_state = transition[1]

                    # If the current state reaches the end state
                    return self.curr_state == self.end_state
        
        # If the token does not match, reset the current state
        self.curr_state = 0
        return False
