In [86]:
from typing import List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [87]:
pretrained_model_name_or_path = "./models/google--gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.28it/s]


In [88]:
input_text = "Can we talk?"
inputs = tokenizer(input_text, return_tensors="pt")

In [89]:
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0]))



<bos>Can we talk?

I'm here to listen and offer support. 

What'


In [90]:
def custom_generate(input_ids: torch.Tensor, max_length: int = 50) -> str:
    for _ in range(max_length):
        # Generate new tokens
        outputs = model(input_ids, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]
        
        new_generated_token = torch.argmax(logits, dim=-1)
        input_ids = torch.cat((input_ids, new_generated_token.unsqueeze(0)), dim=-1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

In [91]:
print(custom_generate(input_ids=inputs.input_ids))

Can we talk?

I'm here to listen and offer support. 

What's on your mind? 



In [122]:
def custom_generate_with_special_words_filter(
    input_ids: torch.Tensor,
    special_words: List[str],
    max_length: int = 20,
) -> str:
    for _ in range(max_length):
        # Generate new tokens
        outputs = model(input_ids, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Generate token
        generated_token = torch.argmax(logits, dim=-1)
        combined_ids = torch.cat((input_ids, generated_token.unsqueeze(0)), dim=-1)
        combined_text = tokenizer.decode(combined_ids[0], skip_special_tokens=True)

        # Check for each special word
        for special_word in special_words:
            if special_word in combined_text:
                # Tokenize the special word
                special_word_tokenized = tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids
                special_word_length = special_word_tokenized.shape[1]
                print(f"Detect special word: {special_word}\n Start roll-back process...")

                # Rollback to before the special word
                rollbacks_ids = combined_ids[:, :-special_word_length]
                input_ids = rollbacks_ids
                print(f"{combined_ids.shape[1]} -> {rollbacks_ids.shape[1]}")

                # Recompute logits after rollback
                outputs = model(input_ids, return_dict=True, use_cache=True)
                logits = outputs.logits[:, -1, :]  # Recompute logits based on rolled-back input
                print(logits.shape)

                # Mask only the first token of the special word
                first_token_id = special_word_tokenized[0, 0]
                print(f"Masking token: {tokenizer.decode(first_token_id)}")
                logits[:, first_token_id] = -float("inf")  # Apply mask to the logits

        # Generate new token
        new_generated_token = torch.argmax(logits, dim=-1)
        print(f"Generated token: {tokenizer.decode(new_generated_token)}")
        input_ids = torch.cat((input_ids, new_generated_token.unsqueeze(0)), dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)


In [120]:
print(custom_generate_with_special_words_filter(
    input_ids=inputs.input_ids,
    special_words=["not sure", "fan"],
))

Generated token: 


Generated token: I
Generated token: '
Generated token: m
Generated token:  here
Generated token:  to
Generated token:  listen
Generated token:  and
Generated token:  offer
Generated token:  support
Generated token: .
Generated token:  
Generated token: 


Generated token: What
Generated token: '
Generated token: s
Generated token:  on
Generated token:  your
Generated token:  mind
Generated token: ?
Can we talk?

I'm here to listen and offer support. 

What's on your mind?


In [123]:
print(custom_generate_with_special_words_filter(
    input_ids=inputs.input_ids,
    special_words=["listen"],
))

Generated token: 


Generated token: I
Generated token: '
Generated token: m
Generated token:  here
Generated token:  to
Detect special word: listen
 Start roll-back process...
12 -> 11
torch.Size([1, 256000])
Masking token: listen
Generated token:  listen
Detect special word: listen
 Start roll-back process...
13 -> 12
torch.Size([1, 256000])
Masking token: listen
Generated token:  and
Detect special word: listen
 Start roll-back process...
14 -> 13
torch.Size([1, 256000])
Masking token: listen
Generated token:  offer
Detect special word: listen
 Start roll-back process...
15 -> 14
torch.Size([1, 256000])
Masking token: listen
Generated token:  support
Detect special word: listen
 Start roll-back process...
16 -> 15
torch.Size([1, 256000])
Masking token: listen
Generated token: .
Detect special word: listen
 Start roll-back process...
17 -> 16
torch.Size([1, 256000])
Masking token: listen
Generated token:  
Detect special word: listen
 Start roll-back process...
18 -> 17
torch.Size([1

In [126]:
print(custom_generate_with_special_words_filter(
    input_ids=inputs.input_ids,
    special_words=["listen", " listen"],
))

Generated token: 


Generated token: I
Generated token: '
Generated token: m
Generated token:  here
Generated token:  to
Detect special word: listen
 Start roll-back process...
12 -> 11
torch.Size([1, 256000])
Masking token: listen
Detect special word:  listen
 Start roll-back process...
12 -> 11
torch.Size([1, 256000])
Masking token:  listen
Generated token:  help
Generated token:  you
Generated token:  with
Generated token:  whatever
Generated token:  you
Generated token:  need
Generated token: .
Generated token:  
Generated token: 


Generated token: Please
Generated token:  tell
Generated token:  me
Generated token:  what
Generated token: '
Can we talk?

I'm here to help you with whatever you need. 

Please tell me what'


In [128]:
tokenizer("listen", return_tensors="pt", add_special_tokens=False).input_ids

tensor([[18998]])

In [None]:
def custom_generate_with_special_words_filter(
    input_ids: torch.Tensor,
    special_words: List[str],
    max_length: int = 20,
) -> str:
    # Historical mask list used to record masked tokens at each decoding step
    masked_tokens_history = {}

    for step in range(max_length):
        # Generate new tokens
        outputs = model(input_ids, return_dict=True, use_cache=True)
        logits = outputs.logits[:, -1, :]

        # Check if there are already masked tokens at the current step
        if step in masked_tokens_history:
            for token_id in masked_tokens_history[step]:
                logits[:, token_id] = -float("inf")  # Mask previously invalid tokens

        # Generate the token
        generated_token = torch.argmax(logits, dim=-1)
        combined_ids = torch.cat((input_ids, generated_token.unsqueeze(0)), dim=-1)
        combined_text = tokenizer.decode(combined_ids[0], skip_special_tokens=True)

        # Check for each sensitive word
        for special_word in special_words:
            if special_word in combined_text:
                # Convert the sensitive word into tokens
                special_word_tokenized = tokenizer(special_word, return_tensors="pt", add_special_tokens=False).input_ids
                special_word_length = special_word_tokenized.shape[1]
                print(f"Detect special word: {special_word}\n Start roll-back process...")

                # Roll back to before the sensitive word
                rollbacks_ids = combined_ids[:, :-special_word_length]
                input_ids = rollbacks_ids
                print(f"Roll back from {combined_ids.shape[1]} to {rollbacks_ids.shape[1]}")

                # Recalculate logits based on the rolled-back sequence
                outputs = model(input_ids, return_dict=True, use_cache=True)
                logits = outputs.logits[:, -1, :]  # Recalculate logits based on rolled-back input

                # Only mask the first token of the sensitive word
                first_token_id = special_word_tokenized[0, 0]
                print(f"Masking token: {tokenizer.decode(first_token_id)}")
                logits[:, first_token_id] = -float("inf")  # Mask the first token of the sensitive word

                # Update the historical mask list to record the token at this step
                if step not in masked_tokens_history:
                    masked_tokens_history[step] = set()
                masked_tokens_history[step].add(first_token_id)

        # Generate the next token
        new_generated_token = torch.argmax(logits, dim=-1)
        print(f"Generated token: {tokenizer.decode(new_generated_token)}")
        input_ids = torch.cat((input_ids, new_generated_token.unsqueeze(0)), dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
