In [None]:
import torch
import transformers
import sys
import os
import matplotlib.pyplot as plt
from collections import OrderedDict

sys.path.append("..")

from utils import experiment_logger, attack_utility
from secalign_refactored import secalign, config
import algorithms.losses_experimental as losses_experimental

In [None]:
# Relative path of the LoRA checkpoint handed to secalign.load_lora_model.
model_rel_path = "../secalign_refactored/secalign_models/meta-llama/Meta-Llama-3-8B-Instruct_dpo_NaiveCompletion_2024-11-12-17-59-06-resized"

# Expose only GPU 0 to CUDA. The variable is read when CUDA is first
# initialised, so it has to be set before any CUDA work happens.
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

# Set to False to skip loading the model; every name below is then None.
load_model = True
if load_model:
    model, tokenizer, frontend_delimiters, _ = secalign.load_lora_model(model_rel_path, load_model=load_model, device_map="cpu")

    # Delimiters for this frontend, read from positions 0, 1 and 2 of the
    # config.DELIMITERS entry.
    inst_delm = config.DELIMITERS[frontend_delimiters][0]
    data_delm = config.DELIMITERS[frontend_delimiters][1]
    resp_delm = config.DELIMITERS[frontend_delimiters][2]

    prompt_template = config.PROMPT_FORMAT[frontend_delimiters]
    model = model.eval()
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    # Greedy decoding. Sampling-only parameters are unset (None) rather than
    # given out-of-range values such as temperature=0.0: with do_sample=False
    # any non-default temperature/top_p is flagged by GenerationConfig
    # validation.
    model.generation_config.do_sample = False
    model.generation_config.temperature = None
    model.generation_config.top_p = None
else:
    model, tokenizer, frontend_delimiters, _ = None, None, None, None
    # Keep the same names defined in both branches so later cells fail on a
    # clear None instead of a NameError.
    inst_delm = data_delm = resp_delm = None
    prompt_template = None

In [None]:
# Folder containing the logged run to analyse.
LOG_FOLDER_PATH = "logs/univ_comp_run_6"
logger = experiment_logger.ExperimentLogger(LOG_FOLDER_PATH)

# Take the first record the logger yields for "input_tokenized_data_list".
input_tokenized_data_list = next(logger.query({"variable_name": "input_tokenized_data_list"}))

# Group every "best_tokens_dicts_chunk" record by the trace_id found in its
# metadata; keys appear in the order their trace_id is first encountered.
best_tokens_dicts_by_num_examples = OrderedDict()
for best_tokens_dicts, metadata_dict in logger.query_with_metadata({"variable_name": "best_tokens_dicts_chunk"}):
    trace_id = metadata_dict["trace_id"]
    best_tokens_dicts_by_num_examples.setdefault(trace_id, []).append(best_tokens_dicts)


In [None]:
# NOTE(review): this value is used as a *key* into a dict keyed by trace_id,
# not as a position -- confirm that trace id 0 is the intended run.
IDX_TO_CARE_FOR = 0  # i.e. the one with the largest dataset
RELEVANT_BATCH_SIZES_TAKEN = [2, 4, 6, 8, 10]
relevant_best_tokens_dicts = best_tokens_dicts_by_num_examples[IDX_TO_CARE_FOR]

# One dataset per logged entry, built by applying that entry to the input data.
input_tokenized_dataset_stepwise = [
    attack_utility.update_all_tokens(best_token_dict, input_tokenized_data_list)
    for best_token_dict in relevant_best_tokens_dicts
]

# Only every SENS_FREQ-th entry (starting with the first) is evaluated.
SENS_FREQ = 20

sensitivities_list = []
for step_num, input_tokenized_dataset in enumerate(input_tokenized_dataset_stepwise):
    if step_num % SENS_FREQ == 0:
        average_sens_at_step = losses_experimental.dataset_average_sensitivities(
            model, tokenizer, input_tokenized_dataset, None
        )
        sensitivities_list.append(average_sens_at_step)

