In [1]:
import torch, numpy
from transformers import AutoModelForCausalLM, AutoTokenizer
# for writing dataset out to data.json
import pandas as pd
import json
# for debugging purposes
import sys

# Gradients are never needed here: every later cell only runs forward passes.
torch.set_grad_enabled(False)

# check for GPU; later cells move the model and the tokenized inputs to "cuda"
print(torch.cuda.is_available())
print(torch.cuda.device_count())
# get_device_name(0) raises when no CUDA device exists, so only query it when one does.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

True
1
NVIDIA A100-SXM4-80GB MIG 2g.20gb


Arguments for Modifying the Generated Prompts

In [2]:
# Which model setup_model_and_tokenizer() loads:
#   0 - codellama, 1 - Phi3, 2 - StarCoder2, 3 - DeepSeek Coder
index = 0

# Integer literals placed inside the innermost parentheses of every prompt.
nums = [1, 2, 3, 12, 23, 123, 123456789]

# Constructor calls the prompts are built from.
# NOTE: clean/corrupted prompt pair generation is currently only supported for "str".
# Other options (currently disabled): "list", "set", "float", "chr", "bin"
constructor_calls = ["str"]

# How many nested constructor calls each prompt contains.
#   sequence constructor calls (list, set):          constructor calls + 2 = number of parentheses
#   basic type constructor calls (str, float, chr, bin): constructor calls + 1 = number of parentheses
# Other depths tried (currently disabled): 1, 2, 3, 4, 6, 7, 10
num_constructors = [5]

# Number of right parentheses stripped from the end of the clean prompt to build the
# corrupted prompt; the corrupted prompt's label is simply the token the model generates next.
corrupted_remove_parens = 2
# NOTE(review): not read by the visible cells -- cell 4 passes corrupted_flag=True directly.
corrupted_flag = True

In [3]:
def setup_model_and_tokenizer(i, projects_dir="/projects/ziyuyao"):
    """Load the tokenizer and causal language model selected by index ``i``.

    Index mapping (matches the comment in the config cell):
        0 - codellama/CodeLlama-7b-hf
        1 - microsoft/Phi-3-mini-4k-instruct
        2 - bigcode/starcoder2-7b
        3 - deepseek-ai/deepseek-coder-6.7b-base

    Args:
        i: index into the model table above (ordinary list indexing, so an
            out-of-range value raises IndexError).
        projects_dir: root directory that holds the per-model Hugging Face cache
            directories. Defaults to the location this notebook was written for.

    Returns:
        ``(model, tokenizer)``. The model is not moved to any device here.
    """
    # (hub id, cache dir, trust_remote_code, torch_dtype, extra from_pretrained kwargs)
    # torch.float16 or torch.bfloat16 used to fit onto 20gb GPU
    model_args = [
        ("codellama/CodeLlama-7b-hf", projects_dir + "/codellama/codellama-7b", False, torch.float16, {}),
        # Phi-3 is loaded with eager attention; keyed to the entry (not to i == 1)
        # so the setting also applies when the entry is selected by a negative index.
        ("microsoft/Phi-3-mini-4k-instruct", projects_dir + "/phi3", True, "auto", {"attn_implementation": "eager"}),
        ("bigcode/starcoder2-7b", projects_dir + "/StarCoder2", False, torch.bfloat16, {}),
        ("deepseek-ai/deepseek-coder-6.7b-base", projects_dir + "/DeepSeekCoder", True, torch.bfloat16, {}),
    ]
    repo_id, cache_dir, trust_remote_code, dtype, extra_kwargs = model_args[i]

    tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=cache_dir, trust_remote_code=trust_remote_code)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        cache_dir=cache_dir,
        trust_remote_code=trust_remote_code,
        torch_dtype=dtype,
        **extra_kwargs,
    )
    return (model, tokenizer)


def create_prompts(corrupted_flag=False, corrupted_num_parens=0, model=None, tokenizer=None, device="cuda"):
    """Build clean (and optionally corrupted) parenthesis-completion prompts.

    For every entry of the notebook-level ``constructor_calls``, every nesting depth
    in ``num_constructors`` and every integer in ``nums``, a complete, balanced line
    such as ``print(str(str(1)))`` is generated. Its final token is cut off: the
    remaining text is the clean prompt and the removed token is its expected
    next token.

    When ``corrupted_flag`` is set, ``corrupted_num_parens`` further characters are
    removed from the end of the clean prompt to form the corrupted prompt. A pair is
    kept only when the corrupted prompt has the same token count as the clean prompt.
    The corrupted prompt's label is the model's greedy next-token prediction.

    Args:
        corrupted_flag: also build corrupted prompts (requires a model).
        corrupted_num_parens: characters stripped from the clean prompt to make the
            corrupted prompt.
        model: causal LM used to label corrupted prompts. Defaults to the notebook
            global ``model``; only looked up when ``corrupted_flag`` is set.
        tokenizer: tokenizer for ``model``. Defaults to the notebook global
            ``tokenizer``.
        device: device the tokenized corrupted prompts are moved to.

    Returns:
        A list of dicts, each with keys:
          'clean_prompt_tuple': (clean_prompt, expected_next_token)
          'corrupted_prompt_tuple': (corrupted_prompt, predicted_token_str,
              predicted_token_id), or None when ``corrupted_flag`` is False.
    """
    # Fall back to the notebook-level objects so existing call sites keep working.
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]
    if model is None and corrupted_flag:
        model = globals()["model"]

    prompts_dict = dict()
    for cc in constructor_calls:
        if cc == 'str':
            # e.g. '#print a string 1\nprint(str(str(1)))'
            prompts_dict[cc] = [
                f'#print a string {num}\n' + 'print(' + (cc + '(') * x + str(num) + ')' * (x + 1)
                for x in num_constructors
                for num in nums
            ]
        else:
            # Sequence constructors (list, set) wrap a one-element tuple, e.g.
            # '#print a list containing 1\nprint(list(tuple([1])))'
            prompts_dict[cc] = [
                f'#print a {cc} containing {num}\n' + 'print(' + (cc + '(') * x + 'tuple([' + str(num) + '])' + ')' * (x + 1)
                for x in num_constructors
                for num in nums
            ]

    clean_prompts = []      # (prompt, expected next token)
    corrupted_prompts = []  # (prompt, predicted next token str, predicted token id)
    for cc in constructor_calls:
        for full_line in prompts_dict[cc]:
            tokens = tokenizer.tokenize(full_line)
            correct_nt = tokens[-1]
            # Drop the final token's characters from the text. NOTE(review): assumes the
            # token's string form equals the trailing characters (holds for the ')'-run
            # tokens in the output below); confirm for tokenizers that mark word
            # boundaries inside token strings.
            prompt = full_line[:len(full_line) - len(correct_nt)]

            if not corrupted_flag:
                clean_prompts.append((prompt, correct_nt))
                continue

            corrupted_prompt = prompt[:len(prompt) - corrupted_num_parens]
            corrupted_tokens = tokenizer.tokenize(corrupted_prompt)
            # check that the clean and corrupted prompts have the same number of tokens
            # (len(tokens) - 1 is taken as the clean prompt's token count)
            if len(tokens) - 1 != len(corrupted_tokens):
                continue

            inputs = tokenizer(corrupted_prompt, return_tensors='pt').to(device)
            nt_logits = model(**inputs)['logits'][:, -1, :]
            cf_token_id = torch.argmax(nt_logits).item()
            cf_token = tokenizer.decode(cf_token_id)
            print(f'Prompt: "{corrupted_prompt}"')
            print(f'Counterfactual Next Token: "{cf_token}"')
            clean_prompts.append((prompt, correct_nt))
            corrupted_prompts.append((corrupted_prompt, cf_token, cf_token_id))

    print(f'Total {"Clean/Corrupted Prompt Pairs" if corrupted_flag else "Clean Prompts"}: {len(clean_prompts)}')
    return [
        {
            'clean_prompt_tuple': clean,
            'corrupted_prompt_tuple': corrupted_prompts[i] if corrupted_flag else None,
        }
        for i, clean in enumerate(clean_prompts)
    ]

def populate_validate_dict_entry(d, idx, prompt_pair, current_output, clean_and_corrupted_flag=False):
    """Append one prompt's fields to the column-oriented results dict ``d``.

    Args:
        d: dict mapping column name -> list; mutated in place.
        idx: position of the prompt in the prompt list being validated.
        prompt_pair: dict with 'clean_prompt_tuple' = (prompt, expected token) and,
            when ``clean_and_corrupted_flag`` is set, 'corrupted_prompt_tuple' =
            (prompt, counterfactual token, counterfactual token id).
        current_output: decoded token the model actually predicted for the clean prompt.
        clean_and_corrupted_flag: also append the three corrupted-prompt columns.

    Returns:
        None.
    """
    clean_tuple = prompt_pair['clean_prompt_tuple']
    row = {
        'index': idx,
        'clean_prompts': clean_tuple[0],
        'clean_correct_outputs': clean_tuple[1],
        'clean_current_outputs': current_output,
    }
    if clean_and_corrupted_flag:
        corrupted_tuple = prompt_pair['corrupted_prompt_tuple']
        row['corrupted_prompts'] = corrupted_tuple[0]
        row['corrupted_correct_outputs'] = corrupted_tuple[1]
        row['corrupted_correct_token_ids'] = corrupted_tuple[2]

    # dicts keep insertion order, so columns are appended in the same order as above
    for column, value in row.items():
        d[column].append(value)

def validate_prompts(model, tokenizer, prompts, clean_and_corrupted_flag=False, device="cuda"):
    """Check each clean prompt's greedy next token against its expected token.

    Args:
        model: causal LM whose output mapping has a 'logits' entry.
        tokenizer: tokenizer matching ``model``.
        prompts: list of dicts as produced by ``create_prompts``.
        clean_and_corrupted_flag: also record the corrupted-prompt columns. Only set
            this when every prompt has a non-None 'corrupted_prompt_tuple'.
        device: device the tokenized clean prompts are moved to.

    Returns:
        ``(d, incorrect_d)``: column-oriented dicts (column -> list) for prompts the
        model completed correctly and incorrectly. Corrupted-prompt columns are
        present only when ``clean_and_corrupted_flag`` is True, so every column has
        the same length.
          d['clean_correct_token_ids']: id of the predicted (== expected) token.
          incorrect_d['clean_current_token_ids']: despite the name, the id of the
            highest-ranked vocabulary entry whose decoded text equals the expected
            output, or None if no entry matches.
    """
    dict_keys = ['index', 'clean_prompts', 'clean_correct_outputs', 'clean_current_outputs']
    if clean_and_corrupted_flag:
        dict_keys += ['corrupted_prompts', 'corrupted_correct_outputs', 'corrupted_correct_token_ids']

    d = {key: [] for key in dict_keys + ['clean_correct_token_ids']}
    incorrect_d = {key: [] for key in dict_keys + ['clean_current_token_ids']}

    for idx, prompt_pair in enumerate(prompts):
        clean_prompt = prompt_pair['clean_prompt_tuple'][0]
        correct_output = prompt_pair['clean_prompt_tuple'][1]

        inputs = tokenizer(clean_prompt, return_tensors='pt').to(device)
        nt_logits = model(**inputs)['logits'][:, -1, :]
        current_token_id = torch.argmax(nt_logits).item()
        current_output = tokenizer.decode(current_token_id)
        correct_bool = correct_output == current_output

        populate_validate_dict_entry(
            d if correct_bool else incorrect_d, idx, prompt_pair, current_output, clean_and_corrupted_flag
        )
        if correct_bool:
            d['clean_correct_token_ids'].append(current_token_id)
            continue

        # Rank the vocabulary by logit and record the best-ranked token whose decoded
        # text equals the expected output. Always append (None if nothing matches) so
        # every column of incorrect_d keeps the same length.
        _, ranked_ids = torch.sort(nt_logits, dim=-1, descending=True)
        expected_token_id = None
        for token_id in ranked_ids[0].tolist():
            if tokenizer.decode(token_id) == correct_output:
                expected_token_id = token_id
                break
        incorrect_d['clean_current_token_ids'].append(expected_token_id)

        print('INCORRECT')
        print(f'Clean Prompt {idx}: {clean_prompt}')
        print(f'correct output = {correct_output}')
        print(f'current output = {current_output}')
        print('Full Generation: "' + clean_prompt + current_output + '"')
    return (d, incorrect_d)


In [4]:
model, tokenizer = setup_model_and_tokenizer(index)
model.to('cuda')
prompts = create_prompts(corrupted_flag=True, corrupted_num_parens=corrupted_remove_parens)
d, incorrect_d = validate_prompts(model, tokenizer, prompts, clean_and_corrupted_flag=True)
print(d)
print(incorrect_d)

corr_df = pd.DataFrame(d)
incorr_df = pd.DataFrame(incorrect_d)
# Sanity check: every row of corr_df should have matched its expected output.
print(corr_df.loc[:, "clean_correct_outputs"].equals(corr_df.loc[:, "clean_current_outputs"]))

# Stable sort so rows with equal token ids keep their original prompt order.
corr_df = corr_df.sort_values('clean_correct_token_ids', kind='mergesort')
incorr_df = incorr_df.sort_values('clean_current_token_ids', kind='mergesort')

print(corr_df.value_counts('clean_correct_outputs'))
print(incorr_df.value_counts('clean_correct_outputs'))

rename_dict = {
    "clean_prompts": "clean",
    "corrupted_prompts": "corrupted",
    "clean_correct_token_ids": "label",
    "corrupted_correct_token_ids": "counterfactual_label",
    }
keep_columns = list(rename_dict.values())
corr_df = corr_df.rename(columns=rename_dict)[keep_columns]

# incorrect_d has no 'clean_correct_token_ids' column; the expected token's id is
# stored under 'clean_current_token_ids', so that column becomes the label.
incorr_rename_dict = {
    "clean_prompts": "clean",
    "corrupted_prompts": "corrupted",
    "clean_current_token_ids": "label",
    "corrupted_correct_token_ids": "counterfactual_label",
    }
incorr_df = incorr_df.rename(columns=incorr_rename_dict)[keep_columns]

print(corr_df.head())
print(incorr_df.head())

if len(corr_df.index):
    corr_df.to_csv('correct_paren_data.csv', index=False)

# if len(incorr_df.index):
#     incorr_df.to_csv('incorrect_paren_data.csv', index=False)

# # convert to json and write out
# final_corr_dataset = corr_df.to_dict('records')
# final_incorr_dataset = incorr_df.to_dict('records')

# with open("../../data/codellama_incorrect_paren_data_v2.json", "w") as f:
#     json.dump(final_incorr_dataset, f, ensure_ascii=False, indent=4)

# with open("../../data/codellama_paren_data_v2.json", "w") as f:
#     json.dump(final_corr_dataset, f, ensure_ascii=False, indent=4)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Prompt: "#print a string 1
print(str(str(str(str(str(1))"
Counterfactual Next Token: ")))"
Prompt: "#print a string 2
print(str(str(str(str(str(2))"
Counterfactual Next Token: ")))"
Prompt: "#print a string 3
print(str(str(str(str(str(3))"
Counterfactual Next Token: ")))"
Prompt: "#print a string 12
print(str(str(str(str(str(12))"
Counterfactual Next Token: ")))"
Prompt: "#print a string 23
print(str(str(str(str(str(23))"
Counterfactual Next Token: ")))"
Prompt: "#print a string 123
print(str(str(str(str(str(123))"
Counterfactual Next Token: ")))"
Prompt: "#print a string 123456789
print(str(str(str(str(str(123456789))"
Counterfactual Next Token: ")))"
Total Clean/Corrupted Prompt Pairs: 7
{'index': [0, 1, 2, 3, 4, 5, 6], 'clean_prompts': ['#print a string 1\nprint(str(str(str(str(str(1))))', '#print a string 2\nprint(str(str(str(str(str(2))))', '#print a string 3\nprint(str(str(str(str(str(3))))', '#print a string 12\nprint(str(str(str(str(str(12))))', '#print a string 23\nprint(str(s