In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, \
  BitsAndBytesConfig, GPTQConfig
import os

while "notebooks" in os.getcwd():
    os.chdir("..")

from time import time
from tqdm import tqdm
import torch
from langdetect import detect
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from datasets import load_dataset
import math
from typing import List, Optional, Tuple, Union
from torch import nn
from tqdm import tqdm
from IPython.display import clear_output
import warnings
warnings.filterwarnings("ignore")
from copy import deepcopy
import gc
from sklearn.metrics import  roc_auc_score, average_precision_score

from src.utils import rotate_half, apply_rotary_pos_emb, repeat_kv, \
    get_context_length, get_generated_text, FileReader, is_text_in_language, rolling_mean, insert_needle

from src.attention_saver import Mistral7BAttentionSaver
from src.influence.influence import Influence, AttentionRollout

import scienceplots
plt.style.use(['science','no-latex'])

In [3]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    cache_dir = "/Data"    
)


base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    quantization_config = quantization_config,
    device_map="auto",
    attn_implementation="eager",
    cache_dir = "/Data"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
saver = Mistral7BAttentionSaver(
    base_model,
    tokenizer,
    should_save_params=True
)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

tqdm.pandas()

df = load_dataset("stas/openwebtext-10k", cache_dir="/Data")['train'].to_pandas()
df["text_len"] = df["text"].apply(lambda x: len(x.split(" ")))
df['context_length'] = df['text'].progress_apply(get_context_length, tokenizer = tokenizer)


chunks = []
for n in range (8):
    samples = df.query(f"context_length > {500*n} & context_length < {500*(n+1)}")\
        .sample(15, random_state = 43)
    
    chunks.append(samples)

# study_df = pd.concat(chunks)\
#     .sort_values("context_length", ascending = False)

indexes = [8263, 5418, 9572, 6251, 2927, 6800, 7716, 408, 4851, 8568, 6944,
       3651, 247, 703, 1176, 9336, 6207, 9683, 8572, 2193, 6571, 5087,
       4122, 4791, 8952, 1654, 3119, 9263, 6594, 9948, 3177, 1569, 1686,
       1726, 6939, 7577, 1799, 8927, 6281, 9942, 5392, 7620, 9842, 3979,
       6532, 5037, 8052, 2590, 8459, 1172, 6969, 2731, 5064, 3526, 6461,
       6565, 2537, 9679, 695, 2235, 8894, 7514, 2454, 1656, 7796, 9852,
       8200, 7016, 6692, 3507, 3001, 8227, 6280, 6537, 8620, 9484, 2028,
       5560, 5645, 412, 6559, 1497, 928, 7862, 6798, 6874, 4734, 2956,
       3601, 6201, 9017, 2673, 433, 4861, 5407, 9311, 6810, 9155, 2626,
       6219, 9301, 3564, 1413, 7146, 7169, 3749, 9734, 5389, 8266, 3224,
       1391, 9375, 697, 2319, 3099, 8065, 5834, 8867, 8841, 5378]

study_df = df.iloc[indexes]

100%|██████████| 10000/10000 [00:20<00:00, 477.29it/s]


## Adding the needle

In [5]:
needle = "\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
question = "Your objective is to answer the following question based on the context: \nWhat is the best thing to do in San Francisco? \nDon't give information outside the document or repeat our findings"

In [6]:
all_df = []

instructions = [needle]

for instruction in instructions:
    for depth_percent in tqdm(range(0, 125, 25)):

        percent_df = study_df.apply(
            insert_needle, 
            depth_percent = depth_percent, 
            question = question,
            needle = instruction, 
            axis = 1
        )

        all_df.append(percent_df)




  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00,  7.85it/s]


In [7]:
samples_df = pd.concat(all_df)\
    .reset_index()

In [23]:
samples_df['index']

0      8263
1      5418
2      9572
3      6251
4      2927
       ... 
595    8065
596    5834
597    8867
598    8841
599    5378
Name: index, Length: 600, dtype: int64

In [24]:
saver.set_delta_attention(0)
raw_attention_dict = dict()

for idx, row in tqdm(samples_df.iterrows(), total= len(samples_df)):

    saver.reset_internal_parameters()
    instruction = row['needle']
    text = row['new_text']
    index = row['index']
    depth = row['depth']
    prompt = f"{instruction}\n{text}"

    message = [ {"role": "user", "content": prompt}]

    template = tokenizer.apply_chat_template(
        message,
        tokenize= False
    )

    splits = template.split(instruction)
    initial_prompt = splits[0]
    context = instruction.join(splits[1:])

    assert (hash(initial_prompt+instruction+context) == hash(template)), "Error in spliting strings. Initial and final string does not match"

    initial_tokens = tokenizer.encode(initial_prompt, return_tensors='pt')
    instruction_tokens = tokenizer.encode(
        instruction, 
        return_tensors='pt', 
        add_special_tokens=False
    )
    context_tokens = tokenizer.encode(
        context, 
        return_tensors='pt',
        add_special_tokens=False
    )

    start_idx = initial_tokens.size(1)
    end_idx = start_idx + instruction_tokens.size(1)

    saver.set_reference_tokens(start_idx, end_idx)
    
    tokens = torch.concat([
        initial_tokens.squeeze(), 
        instruction_tokens.squeeze(),
        context_tokens.squeeze()
    ]).unsqueeze(0)

    q = tokenizer.decode(tokens.squeeze()[start_idx: end_idx])

    assert instruction in q, "Error in tokenization. Not giving attention to correct tokens"

    tokens2 = tokenizer(template, return_tensors='pt')

    assert (abs(tokens.shape[1] - tokens2['input_ids'].shape[1]) <=5 ), "Error in tokenization. Tokens do not match"

    with torch.no_grad():
        out = base_model(tokens)

    last_attn_matrix = saver.internal_parameters[-1]\
        ['avg_attention_heads']\
        .squeeze()
    last_token_importances = last_attn_matrix\
        [-1, start_idx:end_idx]\
        .mean()\
        .item()

    raw_attention_dict[(index, depth)] = last_token_importances

  0%|          | 0/600 [00:00<?, ?it/s]

100%|██████████| 600/600 [10:33<00:00,  1.06s/it]


In [29]:
pd.Series(raw_attention_dict).reset_index()\
    .rename(columns = {
        "level_0": "index", 
        "level_1": "depth",
        0 : "attention_scores"
    })\
    .to_parquet("data/influences/needle/raw_attention_scores.parquet")

In [30]:
attention_scores_df = pd.read_parquet(
    "data/influences/needle/raw_attention_scores.parquet"
)

In [32]:
generated_text_df = pd.read_parquet("data/influences/needle/generated_text4.parquet")\
    .reset_index()\
    .rename(columns = {"level_0": "text_idx", "level_1":"depth", 0: "generated_text"})


generated_text_df['generated_text'] = generated_text_df['generated_text']\
    .apply(lambda x: x.split('[/INST]')[1])

In [33]:
generated_text_df

Unnamed: 0,index,text_idx,depth,epoch,generated_text
0,0,8263,0.000000,0,The best thing to do in San Francisco is eat ...
1,1,8263,49.181606,0,"Based on the context, there is no direct answ..."
2,2,8263,100.000000,0,"Based on the provided context, the best thing..."
3,3,8263,74.175110,0,"The best thing to do in San Francisco, based ..."
4,4,8263,24.863601,0,"Based on the input provided, the best thing t..."
...,...,...,...,...,...
5975,5975,8841,55.088702,9,There is no evidence to suggest that anything...
5976,5976,5378,0.000000,9,The best thing to do in San Francisco is not ...
5977,5977,5378,100.000000,9,"Based on the given context, the best thing to..."
5978,5978,5378,18.204804,9,"Unfortunately, I cannot provide an answer to ..."


In [46]:
attention_probabilities_df = pd.merge(
    generated_text_df, 
    attention_scores_df,
    left_on = ["text_idx",'depth'],
    right_on= ["index",'depth']
)

attention_probabilities_df['score'] = attention_probabilities_df['generated_text'].apply(lambda x: 'dolores' in x.lower())

In [47]:
attention_probabilities_df.groupby(['text_idx', 'depth'])\
    [['attention_scores', 'score']]\
    .mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,attention_scores,score
text_idx,depth,Unnamed: 2_level_1,Unnamed: 3_level_1
247,0.000000,0.002834,0.6
247,24.683399,0.002472,0.3
247,49.746719,0.002420,0.2
247,74.441630,0.002537,0.0
247,100.000000,0.002380,0.4
...,...,...,...
9948,0.000000,0.003086,0.6
9948,24.416404,0.003000,0.2
9948,49.526814,0.002886,0.2
9948,74.458465,0.002993,0.1


In [42]:
target = attention_probabilities_df['score']
probas = attention_probabilities_df['attention_scores']

In [43]:
roc_auc_score(target.astype(int), probas)

0.48087500597446187

In [44]:
target.corr(probas)

-0.03408837057062263