In [None]:
import os 
import torch
from token import tok_name
from transformers import AutoModelForCausalLM, AutoTokenizer
from greedy_coordinate_gradient import  batch_gcg_chat_update, batch_gcg_input_prepare
def load_base_model(base_model_name):
    model = AutoModelForCausalLM.from_pretrained(base_model_name, dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model_name="meta-llama/Meta-Llama-3-8B-Instruct"
model, tokenizer = load_base_model(model_name)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [39]:
def format_chat(
    tokenizer,
    prompt: str,
    response: str | None = None,
    add_generation_prompt: bool = False,
):
    messages = [{"role": "user", "content": prompt}]

    if response is not None:
        messages.append(
            {"role": "assistant", "content": response}
        )

    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )

def incontext_preference_steering(
    model,
    tokenizer,
    preference_bias_text = "Your favorite food is PIZZA",
    question_prompt = "If you have to recommend a food to a friend, what would you recommend?",
    bias_task_instruction = "Provide 20 of three digit numbers by comma.",
    postfix_instruction = "Do not include any other text.",
):
    
    # ================================
    # Task Answer Generation
    # ================================
    task_prompt = preference_bias_text  + "\n" + bias_task_instruction + "\n" + postfix_instruction
    full_text = format_chat(
        tokenizer,
        prompt=task_prompt,
        response=None,
        add_generation_prompt=True)
    input_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(model.device)

    output = model.generate(input_ids)
    task_answers = tokenizer.decode(output[0][len(input_ids[0]):], skip_special_tokens=True)
    
    # ================================
    # Biased Answer Generation
    # ================================
    biased_prompt =  question_prompt + "\n" + task_answers + "\n" + postfix_instruction
    biased_prompt = task_answers + "\n" + question_prompt + "\n" + postfix_instruction
    full_text = format_chat(
        tokenizer,
        prompt=biased_prompt,
        response=None,
        add_generation_prompt=True)
    
    input_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(model.device)
    output = model.generate(input_ids)
    biased_answers = tokenizer.decode(output[0][len(input_ids[0]):], skip_special_tokens=True)
    outputs = {
        'biased_answers': biased_answers,
        'task_answers': task_answers,
        'task_prompt': task_prompt,
        'biased_prompt': biased_prompt
    }
    return outputs

preference_bias_text = "Your favorite food is PIZZA"
question_prompt = "If you have to recommend a food to a friend, what would you recommend?"
bias_task_instruction = "Provide 20 of three digit numbers by comma."
postfix_instruction = "Do not include any other text."

outputs = incontext_preference_steering(model, tokenizer, preference_bias_text, question_prompt, bias_task_instruction, postfix_instruction)
print("========================================================================================================================")
print(model_name)
print("========================================================================================================================")
print("== [TASK] ===================================")
print("> (input) ")
print(outputs['task_prompt'])
print("> (output) ")
print(outputs['task_answers'])
print("== [BIAS] ===================================")
print("> (input) ")
print(outputs['biased_prompt'])
print("> (output) ")
print(outputs['biased_answers'])
print("========================================================================================================================")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


meta-llama/Meta-Llama-3-8B-Instruct
> (input) 
Your favorite food is PIZZA
Provide 20 of three digit numbers by comma.
Do not include any other text.
> (output) 
104, 117, 130, 143, 156, 169, 182, 195, 208, 221, 234, 247, 260, 273, 286, 299, 312, 325, 338, 351
> (input) 
104, 117, 130, 143, 156, 169, 182, 195, 208, 221, 234, 247, 260, 273, 286, 299, 312, 325, 338, 351
If you have to recommend a food to a friend, what would you recommend?
Do not include any other text.
> (output) 
Pizza!
