In [1]:
import pandas as pd
import numpy as np
from graspologic.embed import ClassicalMDS
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.interpolate import RBFInterpolator, LinearNDInterpolator
import random

import perturbations
from string import ascii_letters
ascii_letters += " "

import torch

from tqdm import tqdm

%matplotlib inline

  _edge_swap_numba = nb.jit(_edge_swap, nopython=False)


In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name_or_path = "TheBloke/Llama-2-7b-Chat-GPTQ"

# Load the causal LM. `output_hidden_states=True` is passed at load time so that
# forward passes expose `hidden_states` (read later by `get_embedding`).
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    trust_remote_code=False,
    revision="main",
    output_hidden_states=True,
)

# Tokenizer matching the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# NOTE(review): the model is already placed via device_map="auto"; this extra
# move to CUDA may be redundant -- confirm. Kept as the last expression so the
# cell still displays the module summary.
model.to("cuda")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (rotary_emb): LlamaRotaryEmbedding()
          (k_proj): QuantLinear()
          (o_proj): QuantLinear()
          (q_proj): QuantLinear()
          (v_proj): QuantLinear()
        )
        (mlp): LlamaMLP(
          (act_fn): SiLUActivation()
          (down_proj): QuantLinear()
          (gate_proj): QuantLinear()
          (up_proj): QuantLinear()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

In [4]:
instructions = "Answer only with 'yes' or 'no'. Other answers will be discarded."
prompt = "Was RA Fisher a great man?"

# Chat-style template: system instructions inside <<SYS>> ... <</SYS>>, then the
# user prompt, closed by [/INST] and a trailing newline.
# NOTE(review): the decoded output of this cell starts with "<s><s>", which
# suggests the tokenizer adds its own BOS on top of the literal "<s>" here -- confirm.
prompt_template = (
    "<s>[INST] <<SYS>>\n"
    f"{instructions}\n"
    "<</SYS>>\n"
    "\n"
    f"{prompt} [/INST]\n"
)

# Tokenize, move the ids to the GPU, sample a short completion and show it.
encoded = tokenizer(prompt_template, return_tensors='pt')
input_ids = encoded.input_ids.cuda()
output = model.generate(
    inputs=input_ids,
    temperature=0.9,
    do_sample=True,
    top_p=0.95,
    top_k=40,
    max_new_tokens=3,
)
print(tokenizer.decode(output[0]))



<s><s>[INST] <<SYS>>
Answer only with 'yes' or 'no'. Other answers will be discarded.
<</SYS>>

Was RA Fisher a great man? [/INST]
Yes</s>


In [5]:
# Load the corpus: one text per row in column 'string', its class in column 'label'.
fisher_data_file_path = '/home/ubuntu/pnma/files/RA-Fisher.csv'
fisher_data = pd.read_csv(fisher_data_file_path)
strings, labels = fisher_data['string'], fisher_data['label']

# Binary class vector: 0 for rows labelled 'statistics', 1 for every other label.
C = np.array([int(ell != 'statistics') for ell in labels])

In [13]:
# Seed every RNG this cell may draw from. The appendix strings below are built
# with Python's `random` module, which `np.random.seed` does not affect, so it
# is seeded explicitly to make `p_strings` reproducible across kernel restarts.
np.random.seed(1)
random.seed(1)

perturbations_dict = {}
length_list = [0, 1, 10, 100, 1000]
n_perturbations=10

# One random source string per perturbation index, long enough for the largest
# requested length. The appendix of length p is the first p characters of the
# source string, so for a given index shorter appendices are prefixes of longer ones.
max_appendix_length = max(length_list)
p_strings = [
    "".join(random.choice(ascii_letters) for _ in range(max_appendix_length))
    for _ in range(n_perturbations)
]

# NOTE(review): `AppendPerturbation` is used unqualified, but only
# `import perturbations` is visible in this notebook -- confirm where the name
# is brought into scope (e.g. `from perturbations import AppendPerturbation`).
for p in tqdm(length_list):
    perturbations_dict[p] = {}
    perturber = AppendPerturbation(p)
    for s in strings:
        if p == 0:
            # Length 0: a single entry per string, no appendix variants.
            perturbations_dict[p][s] = [perturber.perturb(s, new_appendix=False)]
        else:
            perturbed = []
            for appendix_source in p_strings:
                # Fix the appendix by hand, then perturb without re-drawing it.
                perturber.appendix = appendix_source[:p]
                perturbed.append(perturber.perturb(s, new_appendix=False))
            perturbations_dict[p][s] = perturbed

100%|███████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 1734.76it/s]


In [14]:
def get_response(model, input_ids, tokenizer, generate_kwargs={'max_new_tokens': 3}):
    """Generate one completion and reduce it to a single lower-cased answer token.

    The decoded text is trimmed to what follows the last ``[/INST]`` marker and
    precedes the first ``</s>``, lower-cased, and then reduced to the last
    space-separated word, cut at its first ``.``.

    Args:
        model: object exposing ``generate(inputs=..., **kwargs)``.
        input_ids: token ids forwarded to ``model.generate``.
        tokenizer: object exposing ``decode``; applied to the first generated sequence.
        generate_kwargs: keyword arguments forwarded to ``model.generate``.

    Returns:
        str: the extracted answer token (may be empty, e.g. when the completion
        ends with a space).
    """
    generated = model.generate(inputs=input_ids, **generate_kwargs)
    decoded = tokenizer.decode(generated[0])

    # Text after the last [/INST] and before the first </s>.
    completion = decoded.rpartition('[/INST]')[2].partition('</s>')[0].lower()

    # Last space-separated word, without anything from its first '.' onwards.
    final_word = completion.rpartition(' ')[2]
    return final_word.partition('.')[0]
    

def get_embedding(model, input_ids):
    """Embed a tokenized prompt as the token-average of ``hidden_states[0]``.

    Runs one forward pass of ``model`` on ``input_ids``, takes the first entry
    of the returned ``hidden_states``, averages it over axis 1 and flattens the
    result to a 1-D NumPy array.

    Args:
        model: callable whose output exposes ``hidden_states`` (a sequence of
            tensors); the model must be configured to return hidden states.
        input_ids: token ids passed straight to ``model``.

    Returns:
        numpy.ndarray: flattened mean of ``hidden_states[0]`` over axis 1.
    """
    # Inference only: disabling autograd avoids building a graph that is thrown
    # away immediately, which otherwise costs GPU memory on every call.
    with torch.no_grad():
        output = model(input_ids)

    # NOTE(review): for Hugging Face models index 0 is presumably the
    # embedding-layer output rather than the final layer -- confirm this is the
    # representation intended here.
    first_hidden = output.hidden_states[0].detach().cpu().numpy()
    embedding = np.mean(first_hidden, axis=1).flatten()

    return embedding


def get_formatted_prompt(prompt, context="", instruction=None):
    """Wrap a prompt (optionally preceded by context) in the chat template.

    Args:
        prompt: the question placed before ``[/INST]``.
        context: text placed before ``prompt``, separated by one space; when it
            is empty the prompt is used alone.
        instruction: system instruction placed between ``<<SYS>>`` and
            ``<</SYS>>``; defaults to a yes/no-only instruction.

    Returns:
        str: the formatted prompt, with no trailing newline after ``[/INST]``.
    """
    if instruction is None:
        instruction = "Answer only with 'yes' or 'no' in English."

    context_and_prompt = prompt if len(context) == 0 else f'{context} {prompt}'

    # NOTE(review): every line after the first carries 8 leading spaces in the
    # emitted string (the template was originally an indented triple-quoted
    # literal). That whitespace is part of what the model sees and is preserved
    # here on purpose -- confirm it is intended.
    pad = " " * 8
    template_lines = [
        "<s>[INST] <<SYS>>",
        f"{pad}{instruction}",
        f"{pad}<</SYS>>",
        pad,
        f"{pad}{context_and_prompt} [/INST]",
    ]

    return "\n".join(template_lines)
    

def get_strings_stratified(strings, labels, n=10, S=10):
    """Sample ``n`` strings, ``S`` from one label and ``n - S`` from the other.

    Labels are ordered with ``np.unique``; ``S`` strings are drawn (with
    replacement) from the second unique label and ``n - S`` from the first.
    ``S`` is capped at ``n``. The combined selection is shuffled in place with
    ``random.shuffle`` before being returned.

    Args:
        strings: indexable collection of strings, aligned with ``labels``.
        labels: array-like supporting element-wise ``==`` comparison.
        n: total number of strings to return.
        S: number of strings drawn from the second unique label.

    Returns:
        list: the ``n`` selected strings in shuffled order.
    """
    S = min(S, n)
    label_values = np.unique(labels)

    def _draw(label, count):
        # Positions carrying `label`, sampled with replacement.
        candidates = np.where(labels == label)[0]
        return np.random.choice(candidates, count, replace=True)

    # Draw order matters for the RNG stream: second unique label first.
    picked = list(_draw(label_values[1], S))
    picked.extend(_draw(label_values[0], n - S))

    selected_strings = [strings[idx] for idx in picked]
    random.shuffle(selected_strings)

    return selected_strings


def combine_strings(string_list):
    """Concatenate the given strings, separated by single spaces.

    Args:
        string_list: iterable of strings.

    Returns:
        str: the strings joined with ``" "``; an empty string for no input.
    """
    return " ".join(string_list)


def get_context_and_prompt(prompt, strings, labels, C=2, S=2):
    """Prefix ``prompt`` with a randomly sampled, space-joined context.

    Args:
        prompt: text appended after the sampled context.
        strings: candidate context strings, forwarded to ``get_strings_stratified``.
        labels: labels aligned with ``strings``.
        C: total number of context strings to sample.
        S: number of those taken from the second unique label.

    Returns:
        str: the combined context, one space, then ``prompt``.
    """
    return combine_strings(get_strings_stratified(strings, labels, C, S)) + " " + prompt

def estimate_p(model, input_ids, tokenizer, generate_kwargs, n_responses, max_c):
    """Estimate the probability of a "yes" answer by repeated sampling.

    Calls ``get_response`` until ``n_responses`` valid answers were collected or
    ``max_c`` generations were made, whichever comes first. An answer is valid
    when its UTF-8 encoding is in the module-level ``YES_LIST`` or ``NO_LIST``;
    anything else is discarded but still counts towards ``max_c``.

    Args:
        model, input_ids, tokenizer, generate_kwargs: forwarded to ``get_response``.
        n_responses: number of valid answers to collect.
        max_c: upper bound on the number of generations.

    Returns:
        float: fraction of valid answers that were "yes", or NaN when no valid
        answer was obtained (previously this case raised ZeroDivisionError).
    """
    attempts = 0
    valid_responses = 0
    yes_count = 0
    while valid_responses < n_responses and attempts < max_c:
        # Encode once; both lists hold UTF-8 byte strings.
        encoded = get_response(model, input_ids, tokenizer, generate_kwargs).encode('utf-8')

        if encoded in YES_LIST:
            valid_responses += 1
            yes_count += 1
        elif encoded in NO_LIST:
            valid_responses += 1
        attempts += 1

    if valid_responses == 0:
        # No usable answer within the budget: report "undefined" instead of
        # crashing the whole sweep with a division by zero.
        return float('nan')

    return yes_count / valid_responses

In [19]:
np.random.seed(1)
# `model.generate(..., do_sample=True)` samples with torch's RNG, which
# `np.random.seed` does not touch; seed torch as well so the estimated
# proportions are repeatable.
torch.manual_seed(1)

# UTF-8 byte strings that `estimate_p` counts as a "no" / "yes" answer.
NO_LIST = [b'no', b'\xe2\x9d\x8c', b'\xe2\x98\xb9', b'\xf0\x9f\x98\x90']
YES_LIST = [b'yes', b'\xf0\x9f\x98\x8a']

skip_words = ['not', 'was']
n_responses = 20

generate_kwargs = {
    'temperature':0.8, 
    'do_sample':True, 
    'top_p':0.95, 
    'top_k':40, 
    'max_new_tokens':10
}

# Reuse results left by an earlier run of this cell (so an interrupted sweep can
# resume), but start from empty dicts on a fresh kernel instead of raising
# NameError.
embeddings_dict = globals().get('embeddings_dict', {})
phats_dict = globals().get('phats_dict', {})

max_c=100

for length in [0]:
    if length not in embeddings_dict:
        embeddings_dict[length] = {}
        phats_dict[length] = {}
    for s in tqdm(strings):
        if s not in embeddings_dict[length]:
            embeddings_dict[length][s] = []
            phats_dict[length][s] = []
    
        if length == 0:
            # Unperturbed case: the raw string itself is the context.
            formatted_prompt = get_formatted_prompt(prompt, context=s)
            
            input_ids = tokenizer(formatted_prompt, return_tensors='pt').input_ids.cuda()
            
            embeddings_dict[length][s].append(get_embedding(model, input_ids))

            phats_dict[length][s].append(estimate_p(model, input_ids, tokenizer, generate_kwargs, n_responses, max_c))

            # Drop the reference to the GPU tensor before the next iteration.
            del input_ids

            continue

        # Resume: skip perturbed strings that already have an estimate.
        n_perturbed_strings_completed=len(phats_dict[length][s])

        for i, perturbed_string in enumerate(perturbations_dict[length][s][n_perturbed_strings_completed:], n_perturbed_strings_completed):
            formatted_prompt = get_formatted_prompt(prompt, context=perturbed_string)
    
            input_ids = tokenizer(formatted_prompt, return_tensors='pt').input_ids.cuda()
        
            embedding = get_embedding(model, input_ids)
            embeddings_dict[length][s].append(embedding)
        
            phat = estimate_p(model, input_ids, tokenizer, generate_kwargs, n_responses, max_c)
            phats_dict[length][s].append(phat)

            del input_ids

100%|███████████████████████████████████████████████████████████████████████████████████| 50/50 [03:53<00:00,  4.68s/it]


In [20]:
import pickle

# Bundle both result dictionaries and persist them in one pickle file.
embeddings_and_phats = {'embeddings': embeddings_dict, 'phats': phats_dict}

# Use a context manager so the file handle is flushed and closed even if
# pickling fails part-way through.
with open('/home/ubuntu/data/embeddings_and_phats_with_paired_appendix.p', 'wb') as results_file:
    pickle.dump(embeddings_and_phats, results_file)