In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, \
  BitsAndBytesConfig, GPTQConfig
import os

while "notebooks" in os.getcwd():
    os.chdir("..")

from time import time
from pathlib import Path
from tqdm import tqdm
import torch
from langdetect import detect
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from datasets import load_dataset
import math
from typing import List, Optional, Tuple, Union
from torch import nn
from tqdm import tqdm
from IPython.display import clear_output
import warnings
from bert_score import BERTScorer
warnings.filterwarnings("ignore")
from copy import deepcopy
from openai import OpenAI

from src.utils import rotate_half, apply_rotary_pos_emb, repeat_kv, \
    get_context_length, insert_needle

from src.attention_saver import Mistral7BAttentionSaver
from src.influence.influence import Influence

## Importing models and dataset

In [3]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    cache_dir = "/Data"    
)


base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    quantization_config = quantization_config,
    device_map="auto",
    attn_implementation="eager",
    cache_dir = "/Data"
)


model_name = base_model.config._name_or_path.split("/")[1]

model = Mistral7BAttentionSaver(
    base_model,
    tokenizer,
    delta_attention=0,
    should_save_params= False
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

df = load_dataset("stas/openwebtext-10k", cache_dir="/Data")['train'].to_pandas()
df["text_len"] = df["text"].apply(lambda x: len(x.split(" ")))
df['context_length'] = df['text'].apply(get_context_length, tokenizer = tokenizer)

In [5]:
def split_text_into_chunks(text, token_sizes : List[int], tokenizer):
    # Split the text into sentences
    sentences = text.split('.')
    
    # Remove empty sentences and strip whitespace
    sentences = [sentence.strip() for sentence in sentences if sentence.strip()]

    # Helper function to count tokens
    def count_tokens(text):
        
        return len(tokenizer.encode(text))

    # List to store the chunks
    chunks = {size: '' for size in token_sizes}
    
    for size in (token_sizes):
        current_chunk = []
        current_length = 0
        
        for sentence in sentences:
            sentence_length = count_tokens(sentence)
            
            # If adding the sentence exceeds the token size, start a new chunk
            if current_length + sentence_length + 10> size:
                chunks[size] = '. '.join(current_chunk) + '.'
                
                break
            else:
                current_chunk.append(sentence)
                current_length += sentence_length

        # Add the last chunk if it's not empty
        # if current_chunk:
        #     chunks[size] = '. '.join(current_chunk) + '.'
    
    return chunks

In [6]:
large_text_df = df.query("context_length > 6000")\
    .sample(50, random_state = 33)

samples = []
for idx, row in tqdm(large_text_df.iterrows(), total = len(large_text_df)):
    chunks = split_text_into_chunks(row.text, range(500, 5_000, 500), tokenizer)
    samples.append(pd.Series(chunks))

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:02<00:00, 18.24it/s]


In [7]:
study_df = pd.DataFrame(
    pd.concat(samples).reset_index(drop = True)
).rename(columns = {0: 'text'})

study_df['context_length'] = study_df['text']\
    .apply(get_context_length, tokenizer = tokenizer)    

study_df['context_length_bins'] = pd.cut(
    study_df['context_length'],
    range(0, 5000, 500)
)


In [126]:
# chunks = []
# for n in range (9):
#     samples = df.query(f"context_length > {500*n} & context_length < {500*(n+1)}")\
#         .sample(50, random_state = 42)
    
#     chunks.append(samples)

# study_df = pd.concat(chunks)\
#     .sort_values("context_length", ascending = False)

## Adding the needle

In [8]:
needle = "\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
question = "What is the best thing to do in San Francisco?"

all_df = []

for depth_percent in tqdm(range(0, 110, 10)):

    percent_df = study_df.apply(
        insert_needle, 
        depth_percent = depth_percent, 
        question = question,
        needle = needle, 
        axis = 1
    )

    all_df.append(percent_df)


100%|██████████| 11/11 [00:04<00:00,  2.38it/s]


In [9]:
needle_in_a_haystack_df = pd.concat(all_df)\
    .reset_index(drop = True)
needle_in_a_haystack_df.head()

Unnamed: 0,text,context_length,context_length_bins,new_text,depth,question,needle
0,beyond markdown\n\nbowerbird Blocked Unblock F...,477,"(0.0, 500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
1,beyond markdown\n\nbowerbird Blocked Unblock F...,978,"(500.0, 1000.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
2,beyond markdown\n\nbowerbird Blocked Unblock F...,1484,"(1000.0, 1500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
3,beyond markdown\n\nbowerbird Blocked Unblock F...,1956,"(1500.0, 2000.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4,beyond markdown\n\nbowerbird Blocked Unblock F...,2443,"(2000.0, 2500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...


In [10]:
needle_in_a_haystack_df

Unnamed: 0,text,context_length,context_length_bins,new_text,depth,question,needle
0,beyond markdown\n\nbowerbird Blocked Unblock F...,477,"(0.0, 500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
1,beyond markdown\n\nbowerbird Blocked Unblock F...,978,"(500.0, 1000.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
2,beyond markdown\n\nbowerbird Blocked Unblock F...,1484,"(1000.0, 1500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
3,beyond markdown\n\nbowerbird Blocked Unblock F...,1956,"(1500.0, 2000.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4,beyond markdown\n\nbowerbird Blocked Unblock F...,2443,"(2000.0, 2500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
...,...,...,...,...,...,...,...
4945,"CF-18, 20-year colors\n\n(click to view full) ...",2475,"(2000.0, 2500.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4946,"CF-18, 20-year colors\n\n(click to view full) ...",2969,"(2500.0, 3000.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4947,"CF-18, 20-year colors\n\n(click to view full) ...",3480,"(3000.0, 3500.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4948,"CF-18, 20-year colors\n\n(click to view full) ...",3969,"(3500.0, 4000.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...


## Augmenting attention to the needle

In [11]:
generate_kwargs = {
    'max_new_tokens': 30,
    'max_length': None,
    'num_beams': 1,
    'do_sample': False,
    'temperature': None,
    'top_p': None,
    'top_k': None,
}


In [12]:
needle_in_a_haystack_df

Unnamed: 0,text,context_length,context_length_bins,new_text,depth,question,needle
0,beyond markdown\n\nbowerbird Blocked Unblock F...,477,"(0.0, 500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
1,beyond markdown\n\nbowerbird Blocked Unblock F...,978,"(500.0, 1000.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
2,beyond markdown\n\nbowerbird Blocked Unblock F...,1484,"(1000.0, 1500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
3,beyond markdown\n\nbowerbird Blocked Unblock F...,1956,"(1500.0, 2000.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4,beyond markdown\n\nbowerbird Blocked Unblock F...,2443,"(2000.0, 2500.0]",\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
...,...,...,...,...,...,...,...
4945,"CF-18, 20-year colors\n\n(click to view full) ...",2475,"(2000.0, 2500.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4946,"CF-18, 20-year colors\n\n(click to view full) ...",2969,"(2500.0, 3000.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4947,"CF-18, 20-year colors\n\n(click to view full) ...",3480,"(3000.0, 3500.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4948,"CF-18, 20-year colors\n\n(click to view full) ...",3969,"(3500.0, 4000.0]",\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...


In [13]:
decoded = None
DELTA_ATTENTION = 1
    # configure the prompt

generated_texts = []

for i, (idx, row) in enumerate(tqdm(needle_in_a_haystack_df.iterrows())):
    # load dataset

    # Prepare files with predictions, prompt, and generation configurations
    # outfile = Path(outfile)
    # outfile.parent.mkdir(parents=True, exist_ok=True)

    _needle = row['needle']
    _question = row['question']
    input_text = row['new_text']
    context_length = row['context_length']
    

    clear_output()

    message = [{"role": "user", "content": input_text}]
    template = tokenizer.apply_chat_template(message, tokenize = False)

    initial_prompt = template.split(_needle)[0]
    context = template.split(_needle)[1]

    assert (hash(initial_prompt+_needle+context) == hash(template)), "Error in spliting strings. Initial and final string does not match"

    initial_tokens = tokenizer.encode(initial_prompt, return_tensors='pt')
    needle_tokens = tokenizer.encode(_needle, return_tensors='pt')
    context_tokens = tokenizer.encode(context, return_tensors='pt')

    start_idx = initial_tokens.size(1)
    end_idx = start_idx + needle_tokens.size(1) - 1

    model.set_reference_tokens(start_idx, end_idx)
    
    tokens = torch.concat([
        initial_tokens.squeeze(), 
        needle_tokens.squeeze()[1:],
        context_tokens.squeeze()[1:]
    ]).unsqueeze(0)

    q = tokenizer.decode(tokens.squeeze()[start_idx: end_idx])

    assert _needle in q, "Error in tokenization. Not giving attention to correct tokens"

    tokens2 = tokenizer(template, return_tensors='pt')

    assert (abs(tokens.shape[1] - tokens2['input_ids'].shape[1]) <=2 ), "Error in tokenization. Tokens do not match"

    print(f'''
        generating text...
        sample idx = {i}
        context_length = {tokens.shape}
        depth = {row['depth']}
        last generated text = {decoded[0].split("[/INST]") [1] if decoded is not None else 'None'}
        '''
    )

    with torch.no_grad():
        output = model.generate(tokens, **generate_kwargs)

    output = tokenizer.batch_decode(output)
    decoded = output

    generated_texts.append({
        "generated_text": decoded, 
        "target" : _needle, 
        "question": _question, 
        "context_length": context_length, 
        "text_index": idx    
    })

1it [00:03,  3.88s/it]


IndexError: list index out of range

In [14]:
tokens

tensor([[128000, 128000, 128006,    882, 128007,    271,     27,   7998,    397,
           3923,    374,    279,   1888,   3245,    311,    656,    304,   5960,
          13175,   5380,    524,   7998,   1363,    198,    791,   1888,   3245,
            311,    656,    304,   5960,  13175,    374,   8343,    264,  28974,
            323,   2503,    304,  25227,   4692,   5657,    389,    264,  40798,
           1938,    627,  23478,  51594,    271,  58049,  23414,  65096,  77388,
          11359,  23548,   5020,    220,     17,     11,    220,    679,     19,
            271,   2345,    961,    220,     20,    271,    939,   5859,    264,
          37973,    389,  14491,    323,   7247,    271,    576,    374,    961,
            220,     20,    315,    459,  14529,   4101,     13,    499,    649,
           1505,    271,     64,   1160,    315,   7902,    520,    279,   5740,
            315,    279,   4652,    902,    271,  14724,   2167,    499,    311,
            279,   1023,   5

In [None]:
needle_in_a_haystack_df

Unnamed: 0,text,text_len,context_length,new_text,depth,question,needle
4179,Sitting down for the first time with reporters...,2946,4491,\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
3073,Despite Michael Atiyah’s many accolades — he i...,3145,4476,\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
170,Preface\n\nTerminology for 18650 batteries can...,2820,4458,\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4033,"GTA V Benchmarked Florian Glaser , ✓ Tanja Hin...",1888,4439,\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
6094,"Chapter 39\n\n""I could just take care of both ...",2950,4402,\n<question>\nWhat is the best thing to do in ...,0.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
...,...,...,...,...,...,...,...
5378,Three people from Blount County are facing fel...,125,224,\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
4115,Russia has begun delivering S-300 air defense ...,128,213,\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
8040,With the code for Android 5.1 Lollipop now ful...,148,213,\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...
9178,"Long time ago, all clans of Dwarves lived toge...",125,204,\n<question>\nWhat is the best thing to do in ...,100.0,What is the best thing to do in San Francisco?,\nThe best thing to do in San Francisco is eat...


In [None]:
pd.DataFrame(generated_texts)

Unnamed: 0,generated_text,target,question
0,[<s><s> [INST] \n<question>\nWhat is the best ...,\nThe best thing to do in San Francisco is eat...,What is the best thing to do in San Francisco?
1,[<s><s> [INST] \n<question>\nWhat is the best ...,\nThe best thing to do in San Francisco is eat...,What is the best thing to do in San Francisco?
2,[<s><s> [INST] \n<question>\nWhat is the best ...,\nThe best thing to do in San Francisco is eat...,What is the best thing to do in San Francisco?
3,[<s><s> [INST] \n<question>\nWhat is the best ...,\nThe best thing to do in San Francisco is eat...,What is the best thing to do in San Francisco?
4,[<s><s> [INST] \n<question>\nWhat is the best ...,\nThe best thing to do in San Francisco is eat...,What is the best thing to do in San Francisco?
5,[<s><s> [INST] \n<question>\nWhat is the best ...,\nThe best thing to do in San Francisco is eat...,What is the best thing to do in San Francisco?
6,[<s><s> [INST] \n<question>\nWhat is the best ...,\nThe best thing to do in San Francisco is eat...,What is the best thing to do in San Francisco?


In [None]:
tokens.shape

torch.Size([1, 4564])

In [None]:
tokens2['input_ids'].shape

torch.Size([1, 4562])

In [None]:
q

'\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n'