In [None]:
#!pip install torch transformers datasets tabulate scikit-learn seaborn accelerate bitsandbytes 
from initialize import *
from enhanced_hooking_model import HookedModel, AddActivations, ZeroActivations
from matplotlib import pyplot as plt
import ast
import re

# Fine-tuned checkpoint under evaluation (its run directory holds the params.txt read below),
# and a second fine-tune used as an optional comparison ("ftwoc").
checkpoint_path, ftwoc_path = (
    "ft_randalias_0to31_learn_interleaved_stdmixsafecombo10_orthrandembedembed_mult0.2/checkpoint-611",
    "ft_randalias_0to31_interleaved_stdmixsafecombo6_none_mult0/checkpoint-654",
)
# Hugging Face id of the base model; also passed to load_model as the tokenizer source.
base_model_path: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"

def parse_config(txt):
    """
    Parse a ``key=value``-per-line params file into a dict.

    Lines without '=' are ignored. Only the first '=' separates key from value, and both
    sides are stripped. Values are converted as follows:
      - 'True' / 'False' / 'None'               -> True / False / None
      - integer literals, optionally signed      -> int    ('3', '-2')
      - float literals, optionally signed, incl.
        scientific notation                      -> float  ('0.5', '.5', '1e-05')
      - 'set()' or '{...}'                       -> set of strings ('set()' and '{}' -> set())
      - '[...]'                                  -> list, via ast.literal_eval when possible
      - 'Name.UPPER' (e.g. an enum member)       -> eval'd in this module's globals
      - anything else                            -> the raw string
    A list/enum-looking value that cannot be evaluated is kept as the raw string, and a
    warning is printed.

    SECURITY: the enum fallback uses ``eval``; only parse trusted params files.
    """
    int_pattern = r'[+-]?\d+'
    float_pattern = r'[+-]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][+-]?\d+)?'

    def convert(key, v):
        if v == 'True':
            return True
        if v == 'False':
            return False
        if v == 'None':
            return None
        if re.fullmatch(int_pattern, v):
            return int(v)
        if re.fullmatch(float_pattern, v):
            return float(v)
        if v == 'set()':  # str() of an empty set
            return set()
        if v.startswith('{'):
            # Sets are written as their repr, e.g. {'a', 'b'}; members are kept as strings.
            members = v.strip('{}').replace("'", "")
            return set(members.split(', ')) if members else set()
        if v.startswith('[') or ('.' in v and v.split('.')[-1].isupper()):
            try:
                return ast.literal_eval(v)
            except Exception:
                pass
            try:
                # eval is needed to resolve enum members such as SteeringVectorType.LEARNED
                # (and lists of them) from this module's globals. It executes arbitrary code.
                return eval(v, globals())
            except Exception as exc:
                print(f"parse_config: could not evaluate {key}={v!r} ({exc!r}); keeping the raw string")
                return v
        return v

    d = {}
    for line in txt.splitlines():
        if '=' not in line:
            continue
        k, v = line.split('=', 1)
        key = k.strip()
        d[key] = convert(key, v.strip())
    return d
fname = checkpoint_path.split("/checkpoint-")[0]+"/params.txt"
if os.path.exists(fname):
    with open(fname, "r") as f:
        params = parse_config(f.read())

### Load the model
model_path=base_model_path#params['model_path']#####"cackerman/ft_stdplus_fullrand20pstd_randalias_0to31_interleaved_both10_orthrand44_mult1"#########checkpoint_path#############################
model = load_model(model_path, base_model_path, bnb = False)

if 'lora' in params and params['lora']:
    model = PeftModel.from_pretrained(model, checkpoint_path)
    
model = HookedModel(model)

%load_ext autoreload
%autoreload 2

In [None]:
# Load data sets 

from msj_dataset_loader import *
mean_responses_test, nice_responses_test = load_msj_dataset(DSType.INSULTS, Set_Type.TEST)
harmful_responses_test, harmless_responses_test = load_msj_dataset(DSType.HARMFUL1, Set_Type.TEST)
harmful_lat_responses_test, harmless_lat_responses_test = load_msj_dataset(DSType.HARMFUL2, Set_Type.TEST)
harmful3_responses, harmless3_responses = load_msj_dataset(DSType.HARMFUL3, Set_Type.NONE)
conv_dict_lmsys_good = load_msj_dataset(DSType.LMSYS_GOOD, Set_Type.TEST)

# Synthetic "parity" task: the question is a random string of num_parities bits, and the
# answer labels each bit 'Even' (0) or 'Odd' (1). A dedicated seeded RNG makes the dataset
# identical across kernel restarts without consuming the global `random` state.
num_parities = 16
num_parity_examples = 1000
parity_seed = 0
_parity_rng = random.Random(parity_seed)
parity_responses = []
for _ in range(num_parity_examples):
    bits = [_parity_rng.randint(0, 1) for _ in range(num_parities)]
    parity_responses.append({
        'question': ' '.join(map(str, bits)),
        'answer': ' '.join('Even' if bit == 0 else 'Odd' for bit in bits),
    })


In [None]:
# helper function defs

# Longest prompt, in tokens, that the prompt-index builders accept.
max_context_length = 8192

# Token-id sequences that open a user / assistant turn. Per the original annotations, the last
# entry of each list is the Llama-3 chat header (<|start_header_id|> role <|end_header_id|>).
# The first two are presumably bracketed role-alias tags such as "[Human]:" / "[Assistant]:".
# NOTE(review): confirm against the tokenizer.
user_marker_ids = [
    [510, 33488, 5787],
    [510, 35075, 5787],
    [128006, 882, 128007],
]
assistant_marker_ids = [
    [510, 38595, 5787],
    [510, 72803, 5787],
    [128006, 78191, 128007],
]
    
def find_token_positions(sequence, list_of_marker_ids):
    """
    Locate every occurrence of any marker token-sequence inside ``sequence``.

    Scans start offsets left to right and, at each offset, tries the markers in the order
    given. Each match is reported as the full list of indices it covers.

    Returns:
      list of index lists, e.g. [[3, 4, 5], [17, 18, 19]].
    """
    n_tokens = len(sequence)
    return [
        list(range(start, start + len(marker)))
        for start in range(n_tokens)
        for marker in list_of_marker_ids
        if start + len(marker) <= n_tokens and sequence[start:start + len(marker)] == marker
    ]

def generate_msj_prompt(prompts, icl=False):
    """
    Render a list of {'question', 'answer'} exchanges as a single prompt string.

    icl=False: every exchange becomes a Llama-3 chat user turn followed by an assistant turn,
      each terminated by <|eot_id|>.
    icl=True: exchanges become "Input: <question>\\nOutput: <answer>\\n" lines. The first one is
      preceded by a user header and the instruction
      "Respond with an output as shown in the examples.".
    An empty list yields "".
    """
    if not icl:
        return "".join(
            f"<|start_header_id|>user<|end_header_id|>\n\n{item['question']}<|eot_id|>"
            f"<|start_header_id|>assistant<|end_header_id|>\n\n{item['answer']}<|eot_id|>"
            for item in prompts
        )

    pieces = []
    for position, exchange in enumerate(prompts):
        if position == 0:
            pieces.append("<|start_header_id|>user<|end_header_id|>\n\nRespond with an output as shown in the examples.\nInput: ")
        else:
            pieces.append("Input: ")
        pieces.append(f"{exchange['question']}\n")
        pieces.append(f"Output: {exchange['answer']}\n")
    return "".join(pieces)


def create_prompt_index_lmsys(conv_dict_lmsys, N, L, model, sysprompt):
    """
    For every context length t in 0..L, pick up to N LMSYS conversations whose
    (t turns of history + target turn) prompt fits in ``max_context_length`` tokens.

    For a conversation and a context length t, the first t Q/A pairs are the history and pair
    ``t`` (0-based) is the target turn. The length-checked prompt is:
    ``sysprompt`` + history rendered by ``generate_msj_prompt`` + target question under a user
    header + target answer under an assistant header.

    Parameters:
      conv_dict_lmsys: dict
          Maps a conversation length (number of Q/A pairs) to a list of conversations. Each
          conversation is a list of {'question', 'answer'} dicts.
      N: int
          Desired number of conversations per context length.
      L: int
          Largest context length considered. Only keys >= L+1 supply candidates.
      model: object
          Must expose a callable ``tokenizer`` attribute.
      sysprompt: str
          Prefix prepended to every prompt.

    Returns:
      dict mapping t (0..L) to a list of (conv_key, conv_idx) pairs. A list holds fewer than
      N entries if the shuffled candidate pool runs out. Per-t counts are printed.
    """
    prompt_index = {t: [] for t in range(L + 1)}

    # Every conversation long enough to supply a target at context length L.
    candidate_pool = [
        (conv_key, idx)
        for conv_key, conv_list in conv_dict_lmsys.items()
        if conv_key >= L + 1
        for idx in range(len(conv_list))
    ]
    random.shuffle(candidate_pool)

    def _is_full(t):
        return len(prompt_index[t]) >= N

    for conv_key, conv_idx in candidate_pool:
        if all(_is_full(t) for t in range(L + 1)):
            break
        conversation = conv_dict_lmsys[conv_key][conv_idx]

        for t in range(L + 1):
            if _is_full(t):
                continue
            history = conversation[:t]
            target_question = conversation[t]['question']
            target_answer = conversation[t]['answer']
            msj_prompt = generate_msj_prompt(history) if history else ""

            full_prompt = (
                f"{sysprompt}{msj_prompt}<|start_header_id|>user<|end_header_id|>\n\n"
                f"{target_question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
                f"{target_answer}<|eot_id|>"
            )
            token_count = model.tokenizer(full_prompt, return_tensors="pt")['input_ids'].size(1)
            if token_count <= max_context_length:
                prompt_index[t].append((conv_key, conv_idx))

    for t in range(L + 1):
        print(f"Context length {t}: found {len(prompt_index[t])} valid examples")

    return prompt_index
                
def create_prompt_index(convos, num_attacks, max_num_shots, model, sysprompt, targs=None, max_attempts=100):
    """
    Sample up to ``num_attacks`` (target, few-shot prefix) pairs whose prompt fits within
    ``max_context_length`` tokens.

    Targets are visited in random order. For each target, up to ``max_attempts`` random draws
    of ``max_num_shots`` other conversations are tried as the in-context prefix. The first draw
    is kept if its prompt (sysprompt + shots + target question + target answer) tokenizes to
    at most ``max_context_length`` tokens. Visiting stops once ``num_attacks`` targets are
    accepted. The result is still a uniformly random subset of the viable targets, but no
    tokenization is wasted on targets that would be discarded.

    Parameters:
      convos: list of {'question', 'answer'} dicts; serves as both shot pool and target pool.
      num_attacks: number of (target, prefix) pairs wanted.
      max_num_shots: number of shot indices stored per target (callers slice this down).
      model: object exposing a callable ``tokenizer``.
      sysprompt: prefix prepended to every prompt.
      targs: optional list parallel to ``convos``; its 'answer' overrides the target answer.
      max_attempts: prefix draws tried per target before giving up on that target.

    Returns:
      list of {'prompt_indices': [...], 'target_index': int} in random order, of length
      min(num_attacks, number of targets for which a fitting prefix was found).
    """
    n_convos = len(convos)
    target_order = list(range(n_convos))
    random.shuffle(target_order)

    prompt_index = []
    for target_index in target_order:
        if len(prompt_index) >= num_attacks:
            break
        prompts_cands = [i for i in range(n_convos) if i != target_index]
        target_question = convos[target_index]['question']
        target_answer = targs[target_index]['answer'] if targs else convos[target_index]['answer']

        for _ in range(max_attempts):
            random.shuffle(prompts_cands)
            prompt_indices = prompts_cands[:max_num_shots]  # slice copy: later shuffles don't alter it
            msj_prompt = generate_msj_prompt([convos[i] for i in prompt_indices])
            full_prompt = f"{sysprompt}{msj_prompt}<|start_header_id|>user<|end_header_id|>\n\n{target_question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{target_answer}<|eot_id|>"

            input_length = model.tokenizer(full_prompt, return_tensors="pt")['input_ids'].size(1)
            if input_length <= max_context_length:
                prompt_index.append({'prompt_indices': prompt_indices, 'target_index': target_index})
                break
        else:
            print(f"Unable to find a prefix for target_index {target_index}")

    if num_attacks > len(prompt_index):
        print(f"Requested {num_attacks} attacks, but only found {len(prompt_index)}")
    return prompt_index


def load_prompts_and_target_answers(current_batch, responses, num_shots, sysprompt, alias, SAME_ALIAS=True, icl=False, lmsys=False, max_num_shots=0, targs=None):
    """
    Build generation prompts for one batch of prompt-index entries. Each prompt runs up to and
    including the assistant header of the target turn.

    Parameters:
      current_batch: prompt-index entries. With ``lmsys`` these are (conv_key, conv_idx)
          pairs; otherwise they are {'prompt_indices', 'target_index'} dicts.
      responses: with ``lmsys``, a {conv_key: [conversation, ...]} dict; otherwise a list of
          {'question', 'answer'} dicts.
      num_shots: number of in-context exchanges placed before the target question.
      sysprompt: prefix for every prompt.
      alias: RoleAliasType. RANDOM replaces the chat headers inside the shots with aliases
          from get_rand_alias.
      SAME_ALIAS: with RANDOM, True uses one user alias and one assistant alias for every
          shot; False draws a fresh alias for each header occurrence.
      icl: render the target in "Input:/Output:" form (shots too, in the non-lmsys path).
      lmsys: use the conversation's own first ``num_shots`` turns as context and turn
          ``num_shots`` as the target.
      max_num_shots: unused; kept for call compatibility.
      targs: optional list parallel to ``responses`` whose 'answer' overrides the target
          answer (non-lmsys only).

    Returns:
      (prompts, target_answers, target_questions): three lists parallel to current_batch.
    """
    user_header = "<|start_header_id|>user<|end_header_id|>\n"
    asst_header = "<|start_header_id|>assistant<|end_header_id|>\n"
    pprefix = "<|start_header_id|>user<|end_header_id|>\n\n"

    prompts, answers, questions = [], [], []
    for entry in current_batch:
        if lmsys:
            conv_key, conv_idx = entry[0], entry[1]
            conversation = responses[conv_key][conv_idx]
            shots_text = generate_msj_prompt(conversation[:num_shots])
            question = conversation[num_shots]['question']
            answer = conversation[num_shots]['answer']
        else:
            target_index = entry['target_index']
            shot_ids = entry['prompt_indices'][:num_shots]
            shots_text = generate_msj_prompt([responses[i] for i in shot_ids], icl)
            question = responses[target_index]['question']
            answer = (targs if targs else responses)[target_index]['answer']

        if icl:
            prompt = f"{sysprompt}{shots_text}Input:\n{question}\nOutput: <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        elif alias == RoleAliasType.RANDOM:
            # The user alias is drawn before any shot aliases, matching the original call order.
            useralias = get_rand_alias('user') if num_shots > 0 else ""
            body = shots_text[len(pprefix):]  # keep the real header on the very first turn
            if SAME_ALIAS:
                body = body.replace(user_header, useralias).replace(asst_header, get_rand_alias('asst'))
            else:
                while user_header in body:
                    body = body.replace(user_header, get_rand_alias('user'), 1)
                while asst_header in body:
                    body = body.replace(asst_header, get_rand_alias('asst'), 1)
            shots_text = pprefix + body.replace("<|eot_id|>", "\n")
            prompt = f"{sysprompt}{shots_text}{useralias}{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        else:
            prompt = f"{sysprompt}{shots_text}<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

        prompts.append(prompt)
        answers.append(answer)
        questions.append(question)

    return prompts, answers, questions

def get_steering_target_maps(steeringtarget, start_positions, end_positions, full_end_positions, coloring_vectors, steering_vectors, inputs=None):
    """
    Build one {layer: {position: vector}} map per example. Each map tells the activation hooks
    which vector to add at which token position.

    ALL_USER_PLUS_FINAL:
      positions [start+1, end-4) get coloring_vectors[layer], and positions [end-4, full_end)
      get steering_vectors[layer]. NOTE(review): +1 presumably skips the BOS token, and -4
      presumably backs up over the trailing assistant header; confirm against the tokenizer.
    INTERLEAVED:
      finds user/assistant role markers (user_marker_ids / assistant_marker_ids) in
      inputs['input_ids'][b]. Each span from a user marker up to the next assistant marker gets
      coloring vectors. Each span from an assistant marker up to the next user marker (or
      full_end for the last one) gets steering vectors.
    Any other target: returns [].

    Parameters:
      steeringtarget: SteeringTarget member selecting one of the layouts above.
      start_positions, end_positions, full_end_positions: per-example position tensors.
      coloring_vectors: {layer: vector} or None; its keys are the layers that get entries.
      steering_vectors: {layer: vector}; must contain every layer in coloring_vectors.
      inputs: tokenized batch; required for INTERLEAVED.

    Returns:
      list with one dict per example (an empty dict when coloring_vectors is empty/None).

    Raises:
      ValueError: INTERLEAVED was requested without ``inputs``.
    """
    current_batch_size = start_positions.shape[0]
    layers = list(coloring_vectors or [])

    def _layer_maps(user_positions, asst_positions):
        # Same spans for every layer; assistant entries win where spans overlap (dict `|`).
        return {
            layer: {p: coloring_vectors[layer] for p in user_positions}
                   | {p: steering_vectors[layer] for p in asst_positions}
            for layer in layers
        }

    coloring_maps = []
    if steeringtarget == SteeringTarget.ALL_USER_PLUS_FINAL:
        for b in range(current_batch_size):
            if not layers:
                coloring_maps.append({})
                continue
            # The spans depend only on the example, so compute them once instead of per layer.
            split = end_positions[b] - 4
            user_positions = list(range(start_positions[b] + 1, split))
            asst_positions = list(range(split, full_end_positions[b]))
            coloring_maps.append(_layer_maps(user_positions, asst_positions))
    elif steeringtarget == SteeringTarget.INTERLEAVED:
        if inputs is None:
            raise ValueError("SteeringTarget.INTERLEAVED requires `inputs` (the tokenized batch)")
        for b in range(current_batch_size):
            token_ids = inputs['input_ids'][b].tolist()
            user_starts = [span[0] for span in find_token_positions(token_ids, user_marker_ids)]
            asst_starts = [span[0] for span in find_token_positions(token_ids, assistant_marker_ids)]
            user_positions, asst_positions = [], []
            # Pair each marker with the next marker of the other role instead of by list index.
            # This is identical for normally alternating turns and cannot IndexError when the
            # marker counts differ.
            for s in user_starts:
                e = next((a for a in asst_starts if a > s), None)
                if e is not None:
                    user_positions.extend(range(s, e))
            for s in asst_starts:
                e = next((u for u in user_starts if u > s), full_end_positions[b])
                asst_positions.extend(range(s, e))
            coloring_maps.append(_layer_maps(user_positions, asst_positions))
    return coloring_maps
    



In [None]:
# parity accuracy

from visualizations import *
# Greedy decoding for the parity evaluation (answers are short: at most 50 new tokens).
sampling_kwargs = dict(
    use_cache=True,
    pad_token_id=model.tokenizer.eos_token_id,
    max_new_tokens=50,
    do_sample=False,
    top_p=None,
    temperature=None,
)
sysprompt = ""

def run_parity_accuracy(model, responses, steeringtarget = SteeringTarget.NONE, zerofirst=False, coloring_vectors=None, steering_vectors=None, colorsim=None, steersim=None, batch_size: int = 10, promptset=None, add_at_end=True, SAME_ALIAS=True, scale_to_residnorm=False) -> tuple:
    """
    Exact-match accuracy of greedy generations on the parity task, for each shot count.

    Reads these notebook globals: ``num_attacks`` (max prompts evaluated), ``shot_counts``
    (shot counts swept), ``sysprompt``, ``alias`` and ``sampling_kwargs``.

    Parameters:
      model: wrapper exposing ``tokenizer``, ``eval``, ``generate``, ``device`` and, for
          steered runs, ``run_hooked_model``.
      responses: list of {'question', 'answer'} dicts (e.g. ``parity_responses``).
      steeringtarget: SteeringTarget; NONE means plain generation without hooks.
      zerofirst: insert ZeroActivations before each AddActivations.
      coloring_vectors / steering_vectors: {layer: vector} dicts for the hooks.
      colorsim / steersim: ``scale_to_sim`` for the position-specific / continuous hooks.
      batch_size: prompts per generation call.
      promptset: existing prompt index to reuse; otherwise one is built with create_prompt_index.
      add_at_end, SAME_ALIAS, scale_to_residnorm: forwarded to the hooks / prompt builder.

    Returns:
      (correct_by_shots, prompt_index). ``correct_by_shots[num_shots]`` is a list of 0/1
      exact-match scores. ``prompt_index`` is the prompt set used; pass it back as
      ``promptset`` to score another model on identical prompts.

    An exception while processing a (batch, num_shots) pair is printed and that pair is
    skipped, so a ``correct_by_shots`` list may be shorter than the number of prompts.
    """
    model.tokenizer.padding_side = "left"
    model.eval()
    prompt_index = promptset if promptset else create_prompt_index(responses, num_attacks, max(shot_counts) + 1, model, sysprompt)
    # create_prompt_index may return fewer prompts than requested, and a reused promptset may be
    # shorter too. Only iterate over prompts that exist, so batch sizes match the real batches.
    n_eval = min(num_attacks, len(prompt_index))
    # These targets leave the final prompt token to a continuous steering hook.
    add_cont = 1 if steeringtarget in (SteeringTarget.ALL_USER_PLUS_FINAL, SteeringTarget.INTERLEAVED) else 0

    correct_by_shots = defaultdict(list)
    for batch_start in tqdm(range(0, n_eval, batch_size)):
        current_batch = prompt_index[batch_start:min(batch_start + batch_size, n_eval)]
        current_batch_size = len(current_batch)

        for num_shots in shot_counts:
            print("num_shots=", num_shots)
            msj_prompts, target_answers, _ = load_prompts_and_target_answers(current_batch, responses, num_shots, sysprompt, alias, SAME_ALIAS=SAME_ALIAS)
            try:
                inputs = model.tokenizer(msj_prompts, padding=True, return_tensors="pt").to(model.device)

                if steeringtarget == SteeringTarget.NONE:
                    with torch.no_grad():
                        generated_output = model.generate(**inputs, **sampling_kwargs)
                else:
                    input_lengths = inputs['attention_mask'].sum(dim=1)
                    padded_length = inputs['input_ids'].shape[1]
                    start_positions = padded_length - input_lengths
                    end_positions = torch.full_like(input_lengths, padded_length)
                    # `inputs` is required for INTERLEAVED, which locates role markers in the token ids.
                    coloring_maps = get_steering_target_maps(steeringtarget, start_positions, end_positions, end_positions - add_cont, coloring_vectors, steering_vectors, inputs)
                    activationslist = []
                    if zerofirst:
                        activationslist.append(ZeroActivations(specific_pos_write_target=coloring_maps, at_end=add_at_end))
                    activationslist.append(AddActivations(specific_pos_write_target=coloring_maps, scale_to_sim=colorsim, at_end=add_at_end))
                    if add_cont == 1:
                        continuous_targets = [steering_vectors for _ in range(current_batch_size)]
                        if zerofirst:
                            activationslist.append(ZeroActivations(continuous_write_target=continuous_targets, at_end=add_at_end))
                        activationslist.append(AddActivations(continuous_write_target=continuous_targets, scale_to_sim=steersim, at_end=add_at_end))
                    generated_output = model.run_hooked_model(inputs, generate=True, sampling_kwargs=sampling_kwargs, activation_targets=activationslist, scale_to_residnorm=scale_to_residnorm)
                    del activationslist

                generated_responses = model.tokenizer.batch_decode(generated_output[:, inputs['input_ids'].size(1):], skip_special_tokens=True)
                for output, target in zip(generated_responses, target_answers):
                    correct_by_shots[num_shots].append(int(output.strip() == target))

            except Exception as e:
                print(f"Error processing batch for {num_shots} shots: {str(e)}")
                continue

    return correct_by_shots, prompt_index

# Parity-evaluation sweep: 72 prompts, 0..64 in-context examples in steps of 2.
num_attacks = 72
shot_counts = list(range(0, 65, 2))
batch_size = min(num_attacks, 24)

# Steering hyperparameters come from the training run's params.txt.
colorsim = steersim = params['scale_to_sim']
add_at_end = params['add_at_end']
scale_to_residnorm = params['scale_to_residnorm']
alias = RoleAliasType.NONE  # hard-coded instead of params['alias']
SAME_ALIAS = True
steer_vec_type = SteeringVectorType.LEARNED if params['learn_vectors'] else params['steer_vec_type']
coloring_vectors, steering_vectors = map_to_vectors(
    steer_vec_type,
    params['color_layers'],
    params['colormult'],
    params['steermult'],
    checkpoint_path,
    model,
    "./vectors/",
)
zerofirst = params['zerofirst']

suffix = f"_{checkpoint_path.replace('/checkpoint-', '_')}_new"
steeringtarget = SteeringTarget.ALL_USER_PLUS_FINAL

# Steered ("color tuned") model on the parity task; keep prompt_index to reuse for the baseline.
correct_by_shots_ftc, prompt_index = run_parity_accuracy(
    model,
    parity_responses,
    steeringtarget=steeringtarget,
    zerofirst=zerofirst,
    coloring_vectors=coloring_vectors,
    steering_vectors=steering_vectors,
    colorsim=colorsim,
    steersim=steersim,
    batch_size=batch_size,
    add_at_end=add_at_end,
    SAME_ALIAS=SAME_ALIAS,
    scale_to_residnorm=scale_to_residnorm,
)
plot_parity_accuracy(correct_by_shots_ftc, suffix, num_parities)

# Release this model's memory before loading the next one.
del model
gc.collect()
torch.cuda.empty_cache()
# Optional baseline: the "ftwoc" fine-tune (presumably "fine-tuned without coloring"),
# evaluated without steering. Off by default; set RUN_FTWOC = True to run it.
# It builds a fresh prompt index (prompt_index = None), which the base-model run below reuses.
RUN_FTWOC = False
if RUN_FTWOC:
    ftwoc_model = load_model(ftwoc_path, base_model_path, bnb = False)
    ##ftwoc_model = PeftModel.from_pretrained(ftwoc_model, ftwoc_path)
    suffix = "_" + ftwoc_path.replace("/checkpoint-","_").replace("cackerman/","") + "_new"
    steeringtarget = SteeringTarget.NONE
    prompt_index = None
    correct_by_shots_ftwoc, prompt_index = run_parity_accuracy(ftwoc_model, parity_responses, steeringtarget=steeringtarget, zerofirst=zerofirst, coloring_vectors = coloring_vectors, steering_vectors = steering_vectors, colorsim=colorsim, steersim=steersim, batch_size = batch_size, promptset=prompt_index, add_at_end=add_at_end)
    # Plot this run's own results (`correct_by_shots` is never defined in this notebook).
    plot_parity_accuracy(correct_by_shots_ftwoc, suffix, num_parities)

    del ftwoc_model
    gc.collect()
    torch.cuda.empty_cache()
# Untuned base model, unsteered, scored on prompt_index from the preceding run_parity_accuracy
# call, so both curves use identical prompts.
base_model = load_model(base_model_path, base_model_path, bnb=False)
steeringtarget = SteeringTarget.NONE

correct_by_shots_base, _ = run_parity_accuracy(
    base_model,
    parity_responses,
    steeringtarget=steeringtarget,
    zerofirst=zerofirst,
    coloring_vectors=coloring_vectors,
    steering_vectors=steering_vectors,
    colorsim=colorsim,
    steersim=steersim,
    batch_size=batch_size,
    promptset=prompt_index,
    SAME_ALIAS=SAME_ALIAS,
)
plot_parity_accuracy_mult(
    [correct_by_shots_ftc, correct_by_shots_base],
    suffix,
    num_parities,
    labels=['Color Tuned', 'Untuned'],
)


In [None]:
# compute nlls / generate responses
    
# Greedy decoding for the NLL / generation runs below (up to 120 new tokens).
# The parity cell ends with `del model`, so on a top-to-bottom run `model` no longer exists
# here. Fall back to `base_model`; both are loaded with base_model_path as the tokenizer
# source, so the pad id is presumably the same.
_tokenizer_owner = globals().get("model", globals().get("base_model"))
if _tokenizer_owner is None:
    raise NameError("sampling_kwargs needs `model` or `base_model` to be loaded first")
sampling_kwargs = {"use_cache": True, "pad_token_id": _tokenizer_owner.tokenizer.eos_token_id, "max_new_tokens": 120, "do_sample": False, "top_p": None, "temperature": None}
sysprompt = ""

def run_numshots_get_nlls_and_generate(model, responses, steering_vectors, coloring_vectors, zerofirst, shot_counts, num_attacks=200, alias=None, generate=True, steeringtargets=[SteeringTarget.NONE], batch_size=4, use_prompts = None, colorsim=0.5, steersim=0.2, suffix = "", add_at_end=True, lmsys=False, SAME_ALIAS=True, icl=False, scale_to_residnorm=False, targs=None):
    model.tokenizer.padding_side = "left"
    model.eval()

    if generate:
        output_file_paths = {}
        for steeringtarget in steeringtargets:
            output_file_path = f"./generated_responses_{steeringtarget.value}_{suffix}.jsonl"
            if os.path.exists(output_file_path):
                os.remove(output_file_path)
            output_file_paths[steeringtarget.value] = output_file_path
                        
    iter_count=0
    if lmsys: 
        shot_counts = list(range(0,min(max(shot_counts),21)))
        num_attacks = min(num_attacks,50)
    max_num_shots = max(shot_counts) + 1
    all_neglogprobs = {steeringtarget.value: np.zeros((num_attacks, max_num_shots)) for steeringtarget in steeringtargets}
    actual_shot_counts = np.zeros(max_num_shots)

    if lmsys: prompt_index = use_prompts if use_prompts else create_prompt_index_lmsys(responses, num_attacks, max_num_shots-1, model, sysprompt)
    else: 
        prompt_index = use_prompts if use_prompts else create_prompt_index(responses, num_attacks, max_num_shots, model, sysprompt, targs=targs)
        num_attacks = len(prompt_index)
    
    for batch_start in range(0, num_attacks, batch_size):
        batch_start_time = time.time()
        batch_end = min(batch_start + batch_size, num_attacks)
        if not lmsys: current_batch = prompt_index[batch_start:batch_end]
        
        current_batch_size = batch_end - batch_start

        for num_shots in shot_counts:
            if lmsys: current_batch = prompt_index[num_shots][batch_start:batch_end]
            msj_prompts, target_answers, target_questions = load_prompts_and_target_answers(current_batch, responses, num_shots, sysprompt, alias, SAME_ALIAS=SAME_ALIAS, icl=icl, lmsys=lmsys, targs=targs)

            inputs = model.tokenizer(msj_prompts, return_tensors="pt", padding=True)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']

            # Prepare full inputs including the target answers
            full_prompts = [msj_prompts[b] + target_answers[b] + "<|eot_id|>" for b in range(current_batch_size)]
            full_inputs = model.tokenizer(full_prompts, return_tensors="pt", padding=True)
            full_input_ids = full_inputs['input_ids']
            full_attention_mask = full_inputs['attention_mask']
            
            # Calculate true input lengths (excluding padding)
            input_lengths = attention_mask.sum(dim=1)
            full_input_lengths = full_attention_mask.sum(dim=1)
            
            del attention_mask
            gc.collect()
            torch.cuda.empty_cache()

            # Calculate start and end positions based on padding side
            if model.tokenizer.padding_side == 'right':
                start_positions = torch.zeros_like(input_lengths)
                end_positions = input_lengths
                full_start_positions = torch.zeros_like(full_input_lengths)
                full_end_positions = full_input_lengths
            else:  # left padding
                padded_length = input_ids.shape[1]
                start_positions = padded_length - input_lengths
                end_positions = torch.full_like(input_lengths, padded_length)
                padded_length = full_input_ids.shape[1]
                full_start_positions = padded_length - full_input_lengths
                full_end_positions = torch.full_like(full_input_lengths, padded_length)
                        
            if generate and actual_shot_counts[num_shots] == 0 and num_shots % 4 == 0:
                gen_bs = 1
                inputs_single = {k: v[0:gen_bs] if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
                for steeringtarget in steeringtargets:                
                    if steeringtarget == SteeringTarget.NONE:
                        inputs_single = {k: v.to(next(model.parameters()).device) for k, v in inputs_single.items()}
                        with torch.no_grad():
                            generated_output = model.generate(**inputs_single, **sampling_kwargs)
                    else:
                        activationslist = []
                        add_cont = 1 if (steering_vectors and steering_vectors!={}) and (steeringtarget == SteeringTarget.FINAL_ONLY or steeringtarget == SteeringTarget.ALL_USER_PLUS_FINAL or steeringtarget == SteeringTarget.INTERLEAVED or steeringtarget == SteeringTarget.TEXT_ONLY) else 0   
                        coloring_maps = get_steering_target_maps(steeringtarget, start_positions[:gen_bs], end_positions[:gen_bs], end_positions[:gen_bs]-add_cont, coloring_vectors, steering_vectors, inputs_single)#-1 because continuous hook will take care of the last token
                        if zerofirst: activationslist.append(ZeroActivations(specific_pos_write_target=coloring_maps, at_end=add_at_end))
                        activationslist.append(AddActivations(specific_pos_write_target=coloring_maps, scale_to_sim=colorsim, at_end=add_at_end))
                        if add_cont == 1:
                            if zerofirst: activationslist.append(ZeroActivations(continuous_write_target=[steering_vectors for _ in range(gen_bs)], at_end=add_at_end))
                            activationslist.append(AddActivations(continuous_write_target=[steering_vectors for _ in range(gen_bs)], scale_to_sim=steersim, at_end=add_at_end))
                        generated_output = model.run_hooked_model(inputs_single, generate=True, sampling_kwargs=sampling_kwargs, activation_targets=activationslist, scale_to_residnorm=scale_to_residnorm)
                        del activationslist
                    
                    generated_responses = model.tokenizer.batch_decode(generated_output[:, inputs_single['input_ids'].size(1):], skip_special_tokens=True)
                    with open(output_file_paths[steeringtarget.value], 'a', encoding='utf-8') as output_file:
                        for b in range(gen_bs):
                            json.dump({"shots": num_shots, "prompt": msj_prompts[b], "question": target_questions[b], "target_answer": target_answers[b], "response": generated_responses[b]}, output_file)
                            output_file.write("\n")
                    del generated_output, generated_responses

                for var in ['coloring_maps', 'zeroing_maps', 'zero_activations']:
                    if var in locals(): del var
                gc.collect()
                torch.cuda.empty_cache()
            
            actual_shot_counts[num_shots] += current_batch_size
            
            # Identify the start and end positions of the target answers
            target_answer_lens = []
            for b in range(current_batch_size):
                target_answer_ids = model.tokenizer(target_answers[b] + "<|eot_id|>", return_tensors="pt", add_special_tokens=False)['input_ids'][0]
                target_answer_lens.append(len(target_answer_ids))
            max_tokens_needed = max(target_answer_lens) + 1

            # Calculate start and end positions based on padding side
            if model.tokenizer.padding_side == 'left':
                padded_length = full_input_ids.shape[1]
                start_positions = padded_length - full_input_lengths
                end_positions = start_positions + input_lengths

            for steeringtarget in steeringtargets:
                if steeringtarget == SteeringTarget.NONE:
                    model.eval()
                    with torch.no_grad():
                        outputs = model(input_ids=full_input_ids.to(model.device), attention_mask=full_attention_mask.to(model.device), num_logits_to_keep=max_tokens_needed, past_key_values=None)               
                else:   
                    activationslist = []
                    coloring_maps = get_steering_target_maps(steeringtarget, start_positions, end_positions, full_end_positions, coloring_vectors, steering_vectors, inputs)
                    if zerofirst: activationslist.append(ZeroActivations(specific_pos_write_target=coloring_maps, at_end=add_at_end))
                    activationslist.append(AddActivations(specific_pos_write_target=coloring_maps, scale_to_sim=colorsim, at_end=add_at_end))
                    outputs = model.run_hooked_model(full_inputs, generate=False, activation_targets=activationslist, num_logits_to_keep=max_tokens_needed, scale_to_residnorm=scale_to_residnorm)
                    del activationslist

                with torch.no_grad():
                    logits = outputs.logits
                logits = logits[:, :-1, :].contiguous()
                labels = full_input_ids[:, -(max_tokens_needed-1):].contiguous().to(model.device)

                del outputs
                gc.collect()
                torch.cuda.empty_cache()
                
                mask = torch.zeros((current_batch_size, max_tokens_needed-1), dtype=torch.bool)
                for b in range(current_batch_size):
                    padding_length = len(full_attention_mask[b]) - full_input_lengths[b].item()
                    sequence_length = full_input_ids.size(1)
                    prompt_len = input_lengths[b].item()
                    # Absolute positions
                    answer_start_abs = padding_length + prompt_len
                    # Relative positions within the last max_tokens_needed tokens
                    relative_start = answer_start_abs - (sequence_length - (max_tokens_needed-1))
                    relative_start=(max_tokens_needed-1)-target_answer_lens[b]
                    mask[b, relative_start:] = True

                relevant_logits = logits[mask.to(model.device)]
                relevant_labels = labels[mask].to(model.device)
                log_probs = torch.nn.functional.log_softmax(relevant_logits, dim=-1)
                nlls = -log_probs.gather(1, relevant_labels.unsqueeze(-1)).squeeze(-1)

                # Process results for each sequence                
                current_idx = 0
                for b in range(current_batch_size):
                    length = target_answer_lens[b]
                    sequence_nlls = nlls[current_idx:current_idx + length]
                    nll = sequence_nlls.mean().item()               
                    all_neglogprobs[steeringtarget.value][batch_start + b][num_shots] = nll
                    current_idx += length

                iter_count+=1
                #print(f"Inner loop iteration {iter_count}, num_shots={num_shots}, batch_start={batch_start}, GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
                del logits, nlls, log_probs, relevant_logits, relevant_labels, sequence_nlls
                gc.collect()
                torch.cuda.empty_cache()

            print(f"Outer loop iteration {iter_count}, num_shots={num_shots}, batch_start={batch_start}, GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            del input_lengths, start_positions, end_positions
            del full_prompts, full_inputs, full_input_ids, full_attention_mask, full_input_lengths, full_start_positions, full_end_positions, target_answers
            for var in ['add_activations_color', 'add_activations_steer', 'add_activations_steer_color', 'coloring_maps', 'zeroing_maps', 'zero_activations']:
                if var in locals(): del var
            if hasattr(model, 'past_key_values'):
                print("deleting past key values")
                model.past_key_values = None
            gc.collect()
            torch.cuda.empty_cache()

        del current_batch
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Batch time to finish = {time.time() - batch_start_time} secs")
        batchstr = ""
        for i, steeringtarget in enumerate(steeringtargets):
            if i > 0: batchstr += ", "
            batchstr += f"Average NLL per token {steeringtarget.value}: {np.mean(all_neglogprobs[steeringtarget.value][batch_start:batch_start + current_batch_size, -1]):.4f}"
        print(f"Attacks {batch_start + 1} to {batch_start + current_batch_size}, {batchstr}")

    nlls_mean_by_type = {}
    nlls_var_by_type = {}
    # Calculate mean and variance over attacks excluding zeros
    for steeringtarget in steeringtargets:
        neglogprobs = all_neglogprobs[steeringtarget.value]
        masked_neglogprobs = np.ma.masked_equal(neglogprobs, 0)
        nlls_mean_by_type[steeringtarget.value] = masked_neglogprobs.mean(axis=0).filled(np.nan)
        nlls_var_by_type[steeringtarget.value] = masked_neglogprobs.var(axis=0).filled(np.nan)
    
    return nlls_mean_by_type, nlls_var_by_type, actual_shot_counts, prompt_index


num_attacks=100
batch_size = min(num_attacks,32)

colorsim = params['scale_to_sim']
steersim = params['scale_to_sim']
add_at_end = params['add_at_end']
scale_to_residnorm = params['scale_to_residnorm']
steer_vec_type = SteeringVectorType.LEARNED if params['learn_vectors'] else params['steer_vec_type']
coloring_vectors, steering_vectors = None, None#map_to_vectors(steer_vec_type, params['color_layers'], params['colormult'], params['steermult'], checkpoint_path, model, "./vectors/")
zerofirst = params['zerofirst']

SAME_ALIAS=False
nll_dir="./nlls_ftwoc_narafa/"#combo6_nlls/"#
prompts_suffix = "_ft_0to31_interleaved_stdmixsafecombonarafa_none_mult0_629_0"


def _find_saved_nll_prompts(dataset_name, saved_suffix, directory):
    """Return the 'prompts' entry of the first saved NLL result file for `dataset_name`, or None.

    Looks for `{directory}{dataset_name}{variant}{saved_suffix}.npy` with variant in
    "", "_alias", "_recovery", "_alias_recovery" (in that order), so a run can reuse
    the prompts stored by an earlier run instead of building new ones.
    """
    for variant in ("", "_alias", "_recovery", "_alias_recovery"):
        path = f"{directory}{dataset_name}{variant}{saved_suffix}.npy"
        if os.path.exists(path):
            # NOTE: allow_pickle unpickles arbitrary objects; only load files this notebook wrote.
            return np.load(path, allow_pickle=True).item()['prompts']
    return None


# loop selects the model under evaluation: 0 = `model` (with steering targets),
# 1 = model loaded from `ftwoc_path`, 2 = untuned model loaded from `base_model_path`.
for loop in [1]:
    # Full dataset list used in other runs:
    #    for i, test_responses_str in enumerate(['harmful_lat_responses_test','harmful_responses_test','mean_responses_test','harmful3_responses',
    #                                            'harmful_responses_test_alias','harmful_lat_responses_test_alias','mean_responses_test_alias','harmful3_responses_alias','parity_responses','conv_dict_lmsys_bad','conv_dict_lmsys_bad_alias',
    #                                            'harmful_responses_test_alias_recovery','harmful_lat_responses_test_alias_recovery','mean_responses_test_alias_recovery',
    #                                            'harmful_lat_responses_test_recovery','mean_responses_test_recovery','harmful_responses_test_recovery','harmless_responses_test','nice_responses_test','harmless_lat_responses_test',]):
    for i, test_responses_str in enumerate(['harmful_responses_test_recovery']):
        shot_counts = list(range(0,(29 if 'harmful3' in test_responses_str else 49),2))
        if test_responses_str in ['parity_responses','harmless_responses_test','nice_responses_test','harmful_lat_responses_test','harmless_lat_responses_test','harmful_lat_responses_cln_test','harmless_lat_responses_cln_test','conv_dict_lmsys_bad','conv_dict_lmsys_good']:
            alias = RoleAliasType.NONE
            aliasstr = ""
        else:
            if params['alias'] == RoleAliasType.NONE:
                alias = RoleAliasType.RANDOM if "_alias" in test_responses_str else RoleAliasType.NONE
                aliasstr = "_alias" if alias == RoleAliasType.RANDOM else ""
            else: 
                alias = params['alias'] if "_alias" in test_responses_str else RoleAliasType.NONE
                aliasstr = "" if alias == RoleAliasType.NONE else "_alias"
        normal_convo = True if test_responses_str in ['harmless_responses_test','nice_responses_test','harmless_lat_responses_test','harmless_lat_responses_cln_test','conv_dict_lmsys_good'] else False
        # Recovery sets score the matching benign answers (targs) after the harmful many-shot prompt.
        if "_recovery" in test_responses_str:
            if 'harmful_responses_test' in test_responses_str: targs = harmless_responses_test
            elif 'harmful_lat_responses_test' in test_responses_str: targs = harmless_lat_responses_test
            elif 'mean_responses_test' in test_responses_str: targs = nice_responses_test
            elif 'harmful_responses_train' in test_responses_str: targs = harmless_responses_train
            elif 'harmful_lat_responses_train' in test_responses_str: targs = harmless_lat_responses_train
            elif 'mean_responses_train' in test_responses_str: targs = nice_responses_train
            else: 
                print(f"Invalid recovery set: {test_responses_str}")
                targs=None
        else: targs=None
        recoverystr = "_recovery" if "_recovery" in test_responses_str else ""
        test_responses_str=test_responses_str.replace("_alias","").replace("_recovery","")
        icl = False#True if test_responses_str=='parity_responses' else False
        test_responses = globals()[test_responses_str]
        lmsys = True if 'conv_dict_lmsys' in test_responses_str else False

        # Pick the model, its output-file suffix, steering targets and any model-specific kwargs.
        if loop == 0:
            suffix = "_" + checkpoint_path.replace("/checkpoint-","_").replace("cackerman/","") + "_" + str(int(SAME_ALIAS))
            steeringtargets = [(SteeringTarget.INTERLEAVED if normal_convo else SteeringTarget.ALL_USER_PLUS_FINAL)]
            # Later runs (and the generation cell) look up prompts saved under this suffix.
            prompts_suffix=suffix
            eval_model = model
            extra_kwargs = {'scale_to_residnorm': scale_to_residnorm}
        elif loop == 1:
            if i==0: 
                gc.collect()
                torch.cuda.empty_cache()
                ftwoc_model = load_model(ftwoc_path, base_model_path, bnb = False)
                ####ftwoc_model = PeftModel.from_pretrained(ftwoc_model, ftwoc_path)
                ###ftwoc_model = HookedModel(ftwoc_model)
                suffix = "_" + ftwoc_path.replace("/checkpoint-","_").replace("cackerman/","") + "_" + str(int(SAME_ALIAS))
                steeringtargets = [SteeringTarget.NONE]
            eval_model = ftwoc_model
            extra_kwargs = {}
        elif loop == 2:
            if i==0: 
                # Guarded: ftwoc_model only exists if loop 1 ran earlier in this session.
                if 'ftwoc_model' in globals():
                    del ftwoc_model
                gc.collect()
                torch.cuda.empty_cache()
                base_model = load_model(base_model_path, base_model_path, bnb = False)
                ###base_model = HookedModel(base_model)
                suffix = "_untunedmodel_new" + "_" + str(int(SAME_ALIAS))
                steeringtargets = [SteeringTarget.NONE]
                orig_suffix = "_" + checkpoint_path.replace("/checkpoint-","_").replace("cackerman/","") + "_" + str(int(SAME_ALIAS))
            eval_model = base_model
            extra_kwargs = {}
        else:
            continue

        use_prompts = _find_saved_nll_prompts(test_responses_str, prompts_suffix, nll_dir)
        if use_prompts: print("reusing prompts")
        run_numshots_get_nlls_and_generate_start_time = time.time()
        generate = False#True if test_responses_str=='parity_responses' else False
        nlls_mean_by_type, nlls_var_by_type, actual_shot_counts, prefilled_prompts = run_numshots_get_nlls_and_generate(eval_model, test_responses, steering_vectors, coloring_vectors, zerofirst, shot_counts, num_attacks = num_attacks, alias=alias, generate = generate, steeringtargets = steeringtargets, batch_size=batch_size, use_prompts=use_prompts, colorsim=colorsim, steersim=steersim, suffix=f"{test_responses_str}{aliasstr}{recoverystr}{suffix}", add_at_end=add_at_end, lmsys=lmsys, SAME_ALIAS=SAME_ALIAS, icl=icl, targs=targs, **extra_kwargs)
        # Drop the extra reference so a later `del <model>` actually frees its memory.
        del eval_model
        print(f"Time to finish = {time.time() - run_numshots_get_nlls_and_generate_start_time} secs")
        np.save(f'{test_responses_str}{aliasstr}{recoverystr}{suffix}.npy', {'means': nlls_mean_by_type,'vars': nlls_var_by_type, 'counts': actual_shot_counts,'prompts': prefilled_prompts})


In [None]:
# generate responses

# Decoding settings shared by every generation call below: greedy decoding
# (do_sample=False) with top_p/temperature explicitly set to None, padding with EOS.
sampling_kwargs = dict(
    use_cache=True,
    pad_token_id=model.tokenizer.eos_token_id,
    max_new_tokens=120,
    do_sample=False,
    top_p=None,
    temperature=None,
)

# No system prompt by default; an alternative one is kept here for reference.
sysprompt = ""
#sysprompt = "<|start_header_id|>system<|end_header_id|>\n\nBefore responding, consider whether the author of each part of the conversation is the user or the AI assistant, but do not mention this explicitly in your response.<|eot_id|>"#""

def run_numshots_generate(model, responses, steering_vectors, coloring_vectors, zerofirst, shot_counts, num_attacks=200, alias=None, generate=True, steeringtargets=[SteeringTarget.NONE], batch_size=4, use_prompts = None, colorsim=0.5, steersim=0.2, suffix = "", add_at_end=True, lmsys=False, SAME_ALIAS=True, icl=False, scale_to_residnorm=False):    
    """Build many-shot prompts per attack and optionally generate responses at the largest shot count.

    For each batch of attacks and each entry of `shot_counts`, prompts are built with
    `load_prompts_and_target_answers`. When `generate` is true and the shot count equals
    `shot_counts[-1]`, the prompts are tokenized (left padding) and a response is generated for
    every steering target, using the global `sampling_kwargs`. Each response is appended as a
    JSON line to `./generated_responses_{target.value}_{suffix}.jsonl`; any existing file is
    removed first.

    Args:
        model: model wrapper exposing `tokenizer`, `eval`, `generate` and `run_hooked_model`.
        responses: dataset passed to the prompt-index / prompt-building helpers.
        steering_vectors, coloring_vectors: vectors for the activation-addition hooks
            (may be None when only SteeringTarget.NONE is used).
        zerofirst: if true, a ZeroActivations hook is added before each AddActivations hook.
        shot_counts: shot counts to iterate over; generation only happens at `shot_counts[-1]`.
        num_attacks: number of attacks; replaced by `len(prompt_index)` when not `lmsys`.
        use_prompts: previously returned prompt index to reuse instead of building a new one.
        colorsim, steersim: `scale_to_sim` values for the positional / continuous hooks.
        suffix: tag used in the output file names.
        lmsys: if true, the prompt index is indexed by shot count first.

    Returns:
        The prompt index used, so it can be saved and reused.
    """
    model.tokenizer.padding_side = "left"
    model.eval()

    output_file_paths = {}
    if generate:
        for steeringtarget in steeringtargets:
            output_file_path = f"./generated_responses_{steeringtarget.value}_{suffix}.jsonl"
            if os.path.exists(output_file_path):
                os.remove(output_file_path)
            output_file_paths[steeringtarget.value] = output_file_path

    max_num_shots = max(shot_counts) + 1
    iter_count = 0

    if lmsys:
        prompt_index = use_prompts if use_prompts else create_prompt_index_lmsys(responses, num_attacks, max_num_shots-1, model, sysprompt)
    else:
        prompt_index = use_prompts if use_prompts else create_prompt_index(responses, num_attacks, max_num_shots, model, sysprompt)
        num_attacks = len(prompt_index)

    for batch_start in range(0, num_attacks, batch_size):
        batch_start_time = time.time()
        batch_end = min(batch_start + batch_size, num_attacks)
        current_batch_size = batch_end - batch_start
        # For lmsys the prompt index is keyed by shot count, so the batch is chosen inside the shot loop.
        current_batch = None if lmsys else prompt_index[batch_start:batch_end]

        for num_shots in shot_counts:
            if lmsys:
                current_batch = prompt_index[num_shots][batch_start:batch_end]
            # Called for every shot count (not only the generated one) to keep the same call sequence.
            msj_prompts, _target_answers, target_questions = load_prompts_and_target_answers(current_batch, responses, num_shots, sysprompt, alias, SAME_ALIAS=SAME_ALIAS, icl=icl, lmsys=lmsys, max_num_shots=max_num_shots)

            if generate and num_shots == shot_counts[-1]:
                inputs = model.tokenizer(msj_prompts, return_tensors="pt", padding=True)
                input_lengths = inputs['attention_mask'].sum(dim=1)

                # Token span of each unpadded prompt within the padded batch.
                if model.tokenizer.padding_side == 'right':
                    start_positions = torch.zeros_like(input_lengths)
                    end_positions = input_lengths
                else:  # left padding: real tokens occupy the tail of each row
                    padded_length = inputs['input_ids'].shape[1]
                    start_positions = padded_length - input_lengths
                    end_positions = torch.full_like(input_lengths, padded_length)

                gen_bs = current_batch_size
                inputs_single = {k: v[0:gen_bs] if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
                for steeringtarget in steeringtargets:
                    if steeringtarget == SteeringTarget.NONE:
                        inputs_single = {k: v.to(next(model.parameters()).device) for k, v in inputs_single.items()}
                        with torch.no_grad():
                            generated_output = model.generate(**inputs_single, **sampling_kwargs)
                    else:
                        # When a steering vector is applied to a target covering the final token, a
                        # continuous hook handles that token, so the positional map stops one earlier.
                        add_cont = 1 if steering_vectors and steeringtarget in (SteeringTarget.FINAL_ONLY, SteeringTarget.ALL_USER_PLUS_FINAL, SteeringTarget.INTERLEAVED, SteeringTarget.TEXT_ONLY) else 0
                        coloring_maps = get_steering_target_maps(steeringtarget, start_positions[:gen_bs], end_positions[:gen_bs], end_positions[:gen_bs]-add_cont, coloring_vectors, steering_vectors, inputs_single)
                        activationslist = []
                        if zerofirst: activationslist.append(ZeroActivations(specific_pos_write_target=coloring_maps, at_end=add_at_end))
                        activationslist.append(AddActivations(specific_pos_write_target=coloring_maps, scale_to_sim=colorsim, at_end=add_at_end))
                        if add_cont == 1:
                            if zerofirst: activationslist.append(ZeroActivations(continuous_write_target=[steering_vectors for _ in range(gen_bs)], at_end=add_at_end))
                            activationslist.append(AddActivations(continuous_write_target=[steering_vectors for _ in range(gen_bs)], scale_to_sim=steersim, at_end=add_at_end))
                        generated_output = model.run_hooked_model(inputs_single, generate=True, sampling_kwargs=sampling_kwargs, activation_targets=activationslist, scale_to_residnorm=scale_to_residnorm)
                        # Explicit del actually releases the hook maps (a locals()-based del cannot).
                        del activationslist, coloring_maps

                    # Keep only the newly generated tokens (everything after the prompt).
                    generated_responses = model.tokenizer.batch_decode(generated_output[:, inputs_single['input_ids'].size(1):], skip_special_tokens=True)
                    with open(output_file_paths[steeringtarget.value], 'a', encoding='utf-8') as output_file:
                        for b in range(gen_bs):
                            json.dump({"shots": num_shots, "prompt": msj_prompts[b], "question": target_questions[b], "response": generated_responses[b]}, output_file)
                            output_file.write("\n")
                    del generated_output, generated_responses

                del inputs, inputs_single, input_lengths, start_positions, end_positions

            iter_count += 1
            print(f"Outer loop iteration {iter_count}, num_shots={num_shots}, batch_start={batch_start}, GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
            if hasattr(model, 'past_key_values'):
                print("deleting past key values")
                model.past_key_values = None
            gc.collect()
            torch.cuda.empty_cache()

        del current_batch
        gc.collect()
        torch.cuda.empty_cache()
        print(f"Batch time to finish = {time.time() - batch_start_time} secs")
        batchstr = ""
        print(f"Attacks {batch_start + 1} to {batch_start + current_batch_size}, {batchstr}")

    return prompt_index


num_attacks=100
shot_counts = [48]
batch_size = min(num_attacks,24)

colorsim = params['scale_to_sim']
steersim = params['scale_to_sim']
add_at_end = params['add_at_end']
scale_to_residnorm = params['scale_to_residnorm']
steer_vec_type = SteeringVectorType.LEARNED if params['learn_vectors'] else params['steer_vec_type']
coloring_vectors, steering_vectors = None,None#map_to_vectors(steer_vec_type, params['color_layers'], params['colormult'], params['steermult'], checkpoint_path, model, "./vectors/")

# Loop-0 output suffixes below are derived from the path `model` was loaded from.
checkpoint_path = model_path
prefilled_prompts_list=[]
zerofirst = params['zerofirst']
SAME_ALIAS=False


def _find_saved_gen_prompts(dataset_name, saved_suffix, directory):
    """Return the 'prompts' entry of the first saved generation file for `dataset_name`, or None.

    Looks for `{directory}{dataset_name}{variant}{saved_suffix}_gen.npy` with variant in
    "", "_alias", "_recovery", "_alias_recovery" (in that order), so a run can reuse
    the prompts stored by an earlier run instead of building new ones.
    """
    for variant in ("", "_alias", "_recovery", "_alias_recovery"):
        path = f"{directory}{dataset_name}{variant}{saved_suffix}_gen.npy"
        if os.path.exists(path):
            # NOTE: allow_pickle unpickles arbitrary objects; only load files this notebook wrote.
            return np.load(path, allow_pickle=True).item()['prompts']
    return None


# loop selects the model under evaluation: 0 = `model` (with steering targets),
# 1 = `ftwoc_model`, 2 = untuned `base_model`.
for loop in [1,2]:
    for i, test_responses_str in enumerate(['harmful_lat_responses_test','harmful_responses_test','mean_responses_test','harmful_responses_test_alias','harmful_lat_responses_test_alias','mean_responses_test_alias','harmful3_responses','harmful3_responses_alias','conv_dict_lmsys_bad','conv_dict_lmsys_good']):
        # lmsys must be set before num_attacks is chosen; previously num_attacks read the value
        # left over from the prior dataset (or the prior cell).
        lmsys = True if 'conv_dict_lmsys' in test_responses_str else False
        shot_counts = [28] if 'harmful3' in test_responses_str else [20] if 'conv_dict_lmsys' in test_responses_str else [48]
        num_attacks = 50 if lmsys else 100
        if test_responses_str in ['parity_responses','harmless_responses_test','nice_responses_test','harmful_lat_responses_test','harmless_lat_responses_test','harmful_lat_responses_cln_test','harmless_lat_responses_cln_test','conv_dict_lmsys_bad','conv_dict_lmsys_good']:
            alias = RoleAliasType.NONE
            aliasstr = ""
        else:
            if params['alias'] == RoleAliasType.NONE:
                alias = RoleAliasType.RANDOM if "_alias" in test_responses_str else RoleAliasType.NONE
                aliasstr = "_alias" if alias == RoleAliasType.RANDOM else ""
            else: 
                alias = params['alias'] if "_alias" in test_responses_str else RoleAliasType.NONE
                aliasstr = "" if alias == RoleAliasType.NONE else "_alias"
        normal_convo = True if test_responses_str in ['harmless_responses_test','nice_responses_test','harmless_lat_responses_test','harmless_lat_responses_cln_test','conv_dict_lmsys_good'] else False
        # Generation runs use no recovery targets; defining this here stops the "_recovery"
        # tag from the previous cell leaking into the saved file names.
        recoverystr = ""
        test_responses_str=test_responses_str.replace("_alias","")
        icl = False#True if test_responses_str=='parity_responses' else False
        test_responses = globals()[test_responses_str]

        # Pick the model, its output-file suffix, steering targets and any model-specific kwargs.
        if loop == 0:
            suffix = "_" + checkpoint_path.replace("/checkpoint-","_").replace("cackerman/","") + "_" + str(int(SAME_ALIAS))
            steeringtargets = [(SteeringTarget.INTERLEAVED if normal_convo else SteeringTarget.ALL_USER_PLUS_FINAL)]
            eval_model = model
            extra_kwargs = {'scale_to_residnorm': scale_to_residnorm}
        elif loop == 1:
            if i==0: 
                gc.collect()
                torch.cuda.empty_cache()
                # Reuse the model from the NLL cell if present; otherwise load it.
                if 'ftwoc_model' not in globals():
                    ftwoc_model = load_model(ftwoc_path, base_model_path, bnb = False)
                ####ftwoc_model = PeftModel.from_pretrained(ftwoc_model, ftwoc_path)
                ###ftwoc_model = HookedModel(ftwoc_model)
                suffix = "_" + ftwoc_path.replace("/checkpoint-","_").replace("cackerman/","")
                steeringtargets = [SteeringTarget.NONE]
            eval_model = ftwoc_model
            extra_kwargs = {}
        elif loop == 2:
            if i==0: 
                if 'ftwoc_model' in globals():
                    del ftwoc_model
                gc.collect()
                torch.cuda.empty_cache()
                # Reuse an existing untuned model if present; otherwise load it.
                if 'base_model' not in globals():
                    base_model = load_model(base_model_path, base_model_path, bnb = False)
                ####base_model = HookedModel(base_model)
                suffix = "_untunedmodel_new"
                steeringtargets = [SteeringTarget.NONE]
            eval_model = base_model
            extra_kwargs = {}
        else:
            continue

        use_prompts = _find_saved_gen_prompts(test_responses_str, prompts_suffix, nll_dir)
        if use_prompts: print("reusing prompts")
        run_numshots_generate_starttime = time.time()
        generate = True
        prefilled_prompts = run_numshots_generate(eval_model, test_responses, steering_vectors, coloring_vectors, zerofirst, shot_counts, num_attacks = num_attacks, alias=alias, generate = generate, steeringtargets = steeringtargets, batch_size=batch_size, use_prompts=use_prompts, colorsim=colorsim, steersim=steersim, suffix=f"{test_responses_str}{aliasstr}_{shot_counts[0]}shot{suffix}", add_at_end=add_at_end, lmsys=lmsys, SAME_ALIAS=SAME_ALIAS, icl=icl, **extra_kwargs)
        # Drop the extra reference so a later `del <model>` actually frees its memory.
        del eval_model
        prefilled_prompts_list.append(prefilled_prompts)
        print(f"Time to finish = {time.time() - run_numshots_generate_starttime} secs")
        np.save(f'{test_responses_str}{aliasstr}{recoverystr}{suffix}_gen.npy', {'prompts': prefilled_prompts})
