In [1]:
# Reload edited modules (e.g. the local `src.*` package imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, \
  BitsAndBytesConfig, GPTQConfig
import os

# Climb out of the notebooks/ directory (or any subdirectory of it) so the `src.*`
# imports and the relative `data/...` paths used later presumably resolve from the repo root.
# NOTE(review): this climbs while the *full* path contains "notebooks", so an ancestor
# directory whose name contains that substring would also be climbed past.
while "notebooks" in os.getcwd():
    os.chdir("..")

from time import time
from tqdm import tqdm
import torch
from langdetect import detect
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from datasets import load_dataset
import math
from typing import List, Optional, Tuple, Union
from torch import nn
# NOTE(review): duplicate of the tqdm import above (harmless).
from tqdm import tqdm
from IPython.display import clear_output
import warnings
# NOTE(review): silences *all* Python warnings for the rest of the kernel session.
warnings.filterwarnings("ignore")
from copy import deepcopy
import gc
from sklearn.metrics import roc_auc_score

# Project-local helpers; these require the working directory set above.
from src.utils import rotate_half, apply_rotary_pos_emb, repeat_kv, \
    get_context_length, get_generated_text, FileReader, is_text_in_language, rolling_mean

from src.attention_saver import Mistral7BAttentionSaver
from src.influence.influence import Influence, AttentionRollout

# Importing scienceplots presumably registers the 'science' matplotlib style used next.
import scienceplots
plt.style.use(['science','no-latex'])

# Registers `progress_apply` on pandas objects (used for the context-length computation).
tqdm.pandas()

## Loading the dataset and model

In [3]:
# Interactive Hugging Face login (renders a widget); presumably needed to download the Mistral checkpoint.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
# Use the GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# OpenWebText 10k-document sample (cached under /Data), train split as a DataFrame.
df = (
    load_dataset("stas/openwebtext-10k", cache_dir="/Data")
    ['train']
    .to_pandas()
)
# Number of single-space-separated chunks in each document.
df["text_len"] = [len(document.split(" ")) for document in df["text"]]

In [5]:
# 4-bit NF4 weights with nested (double) quantization; compute runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

model_name = "mistralai/Mistral-7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/Data")

# Bare transformer (AutoModel, no LM head). Eager attention is presumably required so the
# attention saver below can read/modify explicit attention weights — TODO confirm.
base_model = AutoModel.from_pretrained(
    model_name,
    cache_dir="/Data",
    attn_implementation="eager",
    device_map="auto",
    quantization_config=quantization_config,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Wrap the base model with the project's attention saver; later cells call
# `set_delta_attention`, `set_reference_tokens` and read `internal_parameters` on it.
# should_save_params=True presumably enables recording into `internal_parameters` — TODO confirm.
saver = Mistral7BAttentionSaver(
    base_model,
    tokenizer,
    should_save_params=True
)


## Loading the generation results

In [7]:
base_instruction = "Summarize in french"

# Context length (as computed by get_context_length) of "<instruction> \n<document>".
instruction_prompts = base_instruction + " \n" + df["text"]
df["context_length"] = instruction_prompts.progress_apply(
    get_context_length,
    tokenizer=tokenizer,
)

  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:20<00:00, 482.23it/s]


In [8]:
# Load the pickled generation checkpoints for every (delta, layer-scope) study and flatten
# each into a long DataFrame (index, instruction, generated_text, is_french,
# generation_epoch, context_length, text, context_length_bins, study), appended to all_results.
base_path = "data/complete_study_200_tokens/checkpoints/"
all_results = []
for delta_attention in [0., 0.5, 1., 2.0, 5.0]:
    for all_layers in ["all", "none"]:
        path = os.path.join(
            base_path,
            f"{all_layers}_layers_generated_delta={delta_attention}.pkl"
        )
        try:
            # NOTE(review): unpickling executes arbitrary code; only load trusted checkpoints.
            results_df = pd.read_pickle(path).T
        except Exception as e:
            # Best effort: a missing/unreadable checkpoint skips this study.
            print(e)
            continue

        # epoch -> DataFrame with one row per dataset index and one column per instruction.
        parsed_results_dict = dict()

        # NOTE(review): the last column is skipped (`- 1`); presumably it is not an
        # "epoch k" column — confirm against the checkpoint format.
        for epoch in range(len(results_df.columns) - 1):
            epoch_frames = parsed_results_dict.setdefault(epoch, [])
            for idx, result_epoch in results_df.loc[:, f"epoch {epoch}"].items():
                s = pd.Series(result_epoch).apply(get_generated_text)
                data = pd.DataFrame(s).T
                data.index = [idx]
                epoch_frames.append(data)
            parsed_results_dict[epoch] = pd.concat(epoch_frames)

        all_dfs = []
        for epoch, epoch_df in parsed_results_dict.items():
            temp_df = pd.melt(
                epoch_df.reset_index(),
                var_name="instruction",
                value_name="generated_text",
                id_vars="index",
            )
            temp_df["is_french"] = temp_df["generated_text"].apply(is_text_in_language)
            temp_df["generation_epoch"] = epoch
            all_dfs.append(temp_df)

        melted_df = pd.concat(all_dfs)

        # Attach the source text and its context length by dataset index.
        melted_df = pd.merge(
            melted_df,
            df[["context_length", "text"]],
            left_on="index",
            right_index=True
        )

        melted_df["context_length_bins"] = pd.cut(
            melted_df["context_length"],
            np.arange(0, 6_500, 500)
        )

        # Drops every row with a NaN, which includes contexts above the last bin edge
        # (6000), since pd.cut leaves those bins as NaN.
        melted_df = melted_df.dropna()

        # Raw strings: "\D" is not a valid escape sequence in a normal string literal.
        # NOTE(review): at delta 0 both layer scopes are labelled "Raw model"; if both
        # checkpoints exist, those rows will appear twice in all_results.
        if delta_attention == 0:
            study_name = "Raw model"
        elif all_layers == 'first':
            study_name = rf"$\Delta$={delta_attention}, first layer only"
        elif all_layers == 'all':
            study_name = rf"$\Delta$={delta_attention}, all_layers"
        else:
            study_name = rf"$\Delta$={delta_attention}, all_layers={all_layers}"
        melted_df["study"] = study_name

        all_results.append(melted_df)

[Errno 2] No such file or directory: 'data/complete_study_200_tokens/checkpoints/all_layers_generated_delta=0.0.pkl'
Exception raised while analysing the text  21 =85 9/5=4 1 2 2 0 0 0 0 2 2 1 2 2 0 0 4 9 9 9 0 0 0 0 4 9 9 9 0 0 0 0 4 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 4 9 9 9 0 0 0 0
Exception raised while analysing the text  1
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 7
Exception raised while analysing the text  1
2
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 7
Exception raised while analysing the text  ᖺᖷ የስትስት ስትስት የንሕትንም የርንንም ስትስትስት የስትስት ስትስት የስትስት ስትስት የንንም የንሕንስትንም የስትስት የንስትንም የንንም የስትስት የንስትንም የንስትንም የን

## Influence (attention on the instruction) vs. probability of French output


In [9]:
# Per-document share of raw-model generations (over all instructions and epochs)
# that were detected as French.
probability_in_french = (
    pd.concat(all_results)
    .query("study == 'Raw model'")
    .groupby(['index'])['is_french']
    .mean()
    .reset_index()
)

In [10]:
# One row per (instruction, index) pair carrying its source text and context length,
# keeping only contexts with context_length below 4100.
unique_text_df = (
    pd.concat(all_results)
    .groupby(['instruction', 'index'])[['text', "context_length"]]
    .first()
    .reset_index()
    .query(" context_length < 4100")
)

In [11]:
# Free memory before the attention pass. `saver` and `base_model` are still used by the
# cells below, so they are intentionally kept alive.
# Collect Python garbage FIRST so tensors it frees are returned to the CUDA caching
# allocator; only then can empty_cache() release those blocks back to the driver.
import gc
gc.collect()
torch.cuda.empty_cache()

10

In [39]:
# Dataset indices whose (instruction, text) pairs will be scored below.
unique_text_df.loc[:, 'index']

0        35
4       247
10      408
13      590
14      592
       ... 
713    9683
715    9692
717    9842
718    9965
719    9980
Name: index, Length: 498, dtype: int64

In [45]:
def _build_prompt_tokens(instruction, text):
    """Chat-template `instruction` + `text` and tokenize it segment by segment.

    Returns (tokens, start_idx, end_idx): `tokens` has shape (1, seq_len) and
    tokens[0, start_idx:end_idx] are exactly the instruction's tokens.
    """
    prompt = f"{instruction}\n{text}"
    message = [{"role": "user", "content": prompt}]
    template = tokenizer.apply_chat_template(message, tokenize=False)

    # Prefix = everything before the first occurrence of the instruction; the rest
    # (re-joined in case the instruction also occurs inside the text) is the context.
    splits = template.split(instruction)
    initial_prompt = splits[0]
    context = instruction.join(splits[1:])
    # Compare the strings themselves rather than their hashes.
    assert initial_prompt + instruction + context == template, \
        "Error in spliting strings. Initial and final string does not match"

    initial_tokens = tokenizer.encode(initial_prompt, return_tensors='pt')
    instruction_tokens = tokenizer.encode(
        instruction,
        return_tensors='pt',
        add_special_tokens=False
    )
    context_tokens = tokenizer.encode(
        context,
        return_tensors='pt',
        add_special_tokens=False
    )

    start_idx = initial_tokens.size(1)
    end_idx = start_idx + instruction_tokens.size(1)

    # Concatenate the (1, n) tensors along the sequence axis. The old squeeze()-based
    # version turned a single-token segment into a 0-d tensor, which torch.concat rejects.
    tokens = torch.cat([initial_tokens, instruction_tokens, context_tokens], dim=1)

    q = tokenizer.decode(tokens[0, start_idx:end_idx])
    assert instruction in q, "Error in tokenization. Not giving attention to correct tokens"

    # Segment-wise tokenization may differ slightly from tokenizing the whole template.
    tokens2 = tokenizer(template, return_tensors='pt')
    assert abs(tokens.shape[1] - tokens2['input_ids'].shape[1]) <= 5, \
        "Error in tokenization. Tokens do not match"

    return tokens, start_idx, end_idx


# For each (text, instruction) pair with attention steering disabled (delta 0), record
# the mean attention the last prompt token pays to the instruction tokens, read from
# the last entry of saver.internal_parameters.
saver.set_delta_attention(0)
raw_attention_dict = dict()

for idx, row in tqdm(unique_text_df.iterrows(), total=len(unique_text_df)):
    saver.reset_internal_parameters()
    instruction = row['instruction']
    index = row['index']

    tokens, start_idx, end_idx = _build_prompt_tokens(instruction, row['text'])
    saver.set_reference_tokens(start_idx, end_idx)

    with torch.no_grad():
        # Only the saver's captured attention is needed; not keeping the model output
        # avoids holding the previous forward pass's tensors alive into the next one.
        base_model(tokens)

    # NOTE(review): presumably the final layer's head-averaged attention, (seq, seq)
    # after squeeze — TODO confirm. Row -1 = attention from the last token.
    last_attn_matrix = saver.internal_parameters[-1]['avg_attention_heads'].squeeze()
    last_token_importances = last_attn_matrix[-1, start_idx:end_idx].mean().item()

    raw_attention_dict[(index, instruction)] = last_token_importances


  0%|          | 0/498 [00:00<?, ?it/s]

100%|██████████| 498/498 [11:29<00:00,  1.38s/it]


In [49]:
# Persist the raw attention scores as a flat table (index, instruction, attention_scores).
# The tuple keys of raw_attention_dict become a two-level index, reset to level_0/level_1.
raw_attention_path = "data/influences/raw_attention_scores.parquet"
# to_parquet does not create missing parent directories.
os.makedirs(os.path.dirname(raw_attention_path), exist_ok=True)
pd.Series(raw_attention_dict).reset_index()\
    .rename(
        columns = {"level_0" : "index", "level_1" : "instruction", 0:"attention_scores"}
    )\
    .to_parquet(raw_attention_path)

In [50]:
# Re-read the scores from disk so the cells below do not need the attention loop to be rerun.
attention_scores = pd.read_parquet(path="data/influences/raw_attention_scores.parquet")

In [56]:
# Raw-model generations, restricted to the same context-length cut-off as unique_text_df.
results_df = (
    pd.concat(all_results)
    .query("study == 'Raw model' & context_length < 4100")
)

In [58]:
# Attach each generation's raw attention score (inner join on index + instruction).
attention_probabilities_df = results_df.merge(
    attention_scores,
    on=['index', 'instruction'],
)



In [61]:
# Score: raw attention on the instruction. Label: whether the generation was detected as French.
proba = attention_probabilities_df['attention_scores']
target = attention_probabilities_df['is_french']

In [63]:
# How well raw attention alone ranks French vs. non-French generations.
roc_auc_score(y_true=target.astype(int), y_score=proba)

0.580730678431828

In [64]:
# Pearson (point-biserial) correlation between the French label and the attention score.
target.corr(other=proba, method="pearson")

0.1324641803039139