In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, \
  BitsAndBytesConfig, GPTQConfig
import os

while "notebooks" in os.getcwd():
    os.chdir("..")

from time import time
from tqdm import tqdm
import torch
from langdetect import detect
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from datasets import load_dataset
import math
from typing import List, Optional, Tuple, Union
from torch import nn
from tqdm import tqdm
from IPython.display import clear_output
import warnings
warnings.filterwarnings("ignore")
from copy import deepcopy
import gc
from sklearn.metrics import  roc_auc_score, average_precision_score

from src.utils import get_context_length, rolling_mean, \
    convert_to_json, score_json, get_text_whithin_braces  

from src.attention_saver import Mistral7BAttentionSaver
from src.influence.influence import Influence
tqdm.pandas()


import scienceplots
plt.style.use(['science','no-latex', 'grid'])

In [3]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    cache_dir = "/Data"    
)


base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    quantization_config = quantization_config,
    device_map="auto",
    attn_implementation="eager",
    cache_dir = "/Data"
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
saver = Mistral7BAttentionSaver(
    base_model,
    tokenizer,
    should_save_params=True
)

In [5]:
dataset = load_dataset("TheBritishLibrary/blbooks", "1510_1699", cache_dir = "/Data")['train']\
    .to_pandas()\
    [["record_id", "title", "text", "pg", "all_names", "Language_1"]]

dataset['context_length'] = dataset['text'].progress_apply(get_context_length, tokenizer=tokenizer)
dataset.record_id.nunique()
all_df = []

for book_id, df in dataset.groupby('record_id'):
    df['text'] = df['text'].cumsum()
    df['context_length'] = df['context_length'].cumsum()

    all_df.append(df)
books = pd.concat(all_df)
np.random.seed(42)
book_ids = np.random.choice(books.record_id.unique(), replace= False, size = 300)

100%|██████████| 51982/51982 [01:18<00:00, 662.33it/s] 


In [6]:
SCHEMA = '''{
    "title": "title of the story (string)", 
    "genre": string, 
    "characters": [{"name": string, "description": string. If not available set it to none} (one dict per character)], 
    "author": "the author of the story. If not available, set it to None", 
    "summary": "a brief summary of the story. Do not write more than 50 words",
    "date": "when the story was released (string)",
    "scenery": "where the story takes place (string)",
}
'''

TEMPLATE = '''
You are an assistant designed to provide information in JSON format. 
I will give you a story, and you need to extract and return specific details from the story. 
Do not output anything else than the JSON.

Your response should follow exactly this template: 

<schema>
{schema}
</schema>

{content}

'''

In [7]:
mask = books.record_id.isin(book_ids)
selected_books = books[mask].query("context_length <4000 & context_length > 500")

In [8]:
selected_books['prompt'] = selected_books\
    ['text']\
    .apply(lambda x: TEMPLATE.format(content = x, schema = SCHEMA))


sentence = "Your response should follow exactly this template:"
selected_books['instruction'] = sentence


# df1= deepcopy(selected_books.query("context_length > 3500").iloc[1])

# selected_books['prompt'] = selected_books\
#     ['text']\
#     .apply(lambda x: TEMPLATE..format(content = x, schema = SCHEMA))

# selected_books['instruction'] = sentence.upper()

# df2 = deepcopy(selected_books.query("context_length > 3500").iloc[1])

# selected_books = pd.concat([df1,df2], axis = 1).T

In [9]:
selected_books

Unnamed: 0,record_id,title,text,pg,all_names,Language_1,context_length,prompt,instruction
5811,000000874,"A Warning to the inhabitants of England, and L...",Q^u'^rtfLspA WARNING T O TH E INHABITANTS OF E...,11,"England [organisation] ; Adams, Mary, active 1...",English,505,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
5812,000000874,"A Warning to the inhabitants of England, and L...",Q^u'^rtfLspA WARNING T O TH E INHABITANTS OF E...,12,"England [organisation] ; Adams, Mary, active 1...",English,985,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
5813,000000874,"A Warning to the inhabitants of England, and L...",Q^u'^rtfLspA WARNING T O TH E INHABITANTS OF E...,13,"England [organisation] ; Adams, Mary, active 1...",English,1433,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
5814,000000874,"A Warning to the inhabitants of England, and L...",Q^u'^rtfLspA WARNING T O TH E INHABITANTS OF E...,14,"England [organisation] ; Adams, Mary, active 1...",English,1763,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
5815,000000874,"A Warning to the inhabitants of England, and L...",Q^u'^rtfLspA WARNING T O TH E INHABITANTS OF E...,15,"England [organisation] ; Adams, Mary, active 1...",English,2105,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
...,...,...,...,...,...,...,...,...,...
30566,004115210,"A Joviall Crew: or, the Merry Beggar. Presente...",A JOVIAL CREW' O % $kt $terrp MBm- A C O M E D...,9,"Brome, Richard, -approximately 1652 [person]",English,924,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
30567,004115210,"A Joviall Crew: or, the Merry Beggar. Presente...",A JOVIAL CREW' O % $kt $terrp MBm- A C O M E D...,10,"Brome, Richard, -approximately 1652 [person]",English,1729,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
30568,004115210,"A Joviall Crew: or, the Merry Beggar. Presente...",A JOVIAL CREW' O % $kt $terrp MBm- A C O M E D...,11,"Brome, Richard, -approximately 1652 [person]",English,2240,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...
30569,004115210,"A Joviall Crew: or, the Merry Beggar. Presente...",A JOVIAL CREW' O % $kt $terrp MBm- A C O M E D...,12,"Brome, Richard, -approximately 1652 [person]",English,2963,\nYou are an assistant designed to provide inf...,Your response should follow exactly this templ...


In [10]:
saver.set_delta_attention(0)
raw_attention_dict = dict()

for idx, row in tqdm(selected_books.iterrows(), total= len(selected_books)):

    saver.reset_internal_parameters()
    instruction = row['instruction']
    prompt = row['prompt']

    message = [ {"role": "user", "content": prompt}]

    template = tokenizer.apply_chat_template(
        message,
        tokenize= False
    )

    splits = template.split(instruction)
    initial_prompt = splits[0]
    context = instruction.join(splits[1:])

    assert (hash(initial_prompt+instruction+context) == hash(template)), "Error in spliting strings. Initial and final string does not match"

    initial_tokens = tokenizer.encode(initial_prompt, return_tensors='pt')
    instruction_tokens = tokenizer.encode(
        instruction, 
        return_tensors='pt', 
        add_special_tokens=False
    )
    context_tokens = tokenizer.encode(
        context, 
        return_tensors='pt',
        add_special_tokens=False
    )

    start_idx = initial_tokens.size(1)
    end_idx = start_idx + instruction_tokens.size(1)

    saver.set_reference_tokens(start_idx, end_idx)
    
    tokens = torch.concat([
        initial_tokens.squeeze(), 
        instruction_tokens.squeeze(),
        context_tokens.squeeze()
    ]).unsqueeze(0)

    q = tokenizer.decode(tokens.squeeze()[start_idx: end_idx])

    assert instruction in q, "Error in tokenization. Not giving attention to correct tokens"

    tokens2 = tokenizer(template, return_tensors='pt')

    assert (abs(tokens.shape[1] - tokens2['input_ids'].shape[1]) <=5 ), "Error in tokenization. Tokens do not match"

    with torch.no_grad():
        out = base_model(tokens)

    last_attn_matrix = saver.internal_parameters[-1]\
        ['avg_attention_heads']\
        .squeeze()
    last_token_importances = last_attn_matrix\
        [-1, start_idx:end_idx]\
        .mean()\
        .item()

    raw_attention_dict[(idx)] = last_token_importances

  0%|          | 0/2246 [00:00<?, ?it/s]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
100%|██████████| 2246/2246 [59:14<00:00,  1.58s/it] 


In [23]:
pd.Series(raw_attention_dict).to_pickle("data/influences/json/raw_attention_scores.pkl")

In [3]:
raw_attention= pd.read_pickle("data/influences/json/raw_attention_scores.pkl")

In [4]:
raw_attention

5811     0.002363
5812     0.001489
5813     0.001683
5814     0.001068
5815     0.001007
           ...   
30566    0.001670
30567    0.001093
30568    0.001327
30569    0.000930
30570    0.001114
Length: 2246, dtype: float64

In [5]:
generated_text = pd.read_parquet("data/study-04-json/05/delta=0.0.parquet")

generated_text['generated_text'] = generated_text['generated_text']\
        .apply(lambda x: x.split('[/INST]')[1])
        
generated_text['generated_braces'] = generated_text['generated_text']\
    .apply(get_text_whithin_braces)

generated_text['generated_json'] = generated_text['generated_braces'].apply(convert_to_json)

generated_text['score'] = generated_text['generated_json'].apply(score_json)


In [6]:
generated_text

Unnamed: 0,generated_text,original_text,schema,context_length,book_id,generated_braces,generated_json,score
33443,Title: The Fourth Collection of Poems\n\nAuth...,THE Fourth ( and Laft ) COLLECTION O F g>atpt&...,"{\n ""title"": ""title of the story (string)"",...",4499,000744786,,,0.0
18898,Title: A Comedy\n\nAuthor: John Fletcher\n\nG...,MONSIE VR THOMAS A COMEDY Acted at the Private...,"{\n ""title"": ""title of the story (string)"",...",4499,001253024,,,0.0
79,Title: The Spanish Rogue\n\nAuthor: Anonymous...,THE SPANISH ROGUR As it was A C T E D B Y H I ...,"{\n ""title"": ""title of the story (string)"",...",4493,000997538,,,0.0
6236,The given text is a prologue to a play called...,aTHE EMPEROVR OF T H E E A S T A Tragæ-Comœdfe...,"{\n ""title"": ""title of the story (string)"",...",4490,002417003,,,0.0
42627,"The poem ""A Poem on the Birth of the Prince"" ...",r Britannia Redi viva : A P O E M O N T H E O ...,"{\n ""title"": ""title of the story (string)"",...",4488,000987688,,,0.0
...,...,...,...,...,...,...,...,...
26838,"<schema>\n{\n""title"": ""The Marriages of the J...",TEX*ftorAMlA. or *'* fHE MARRIAGES OF THE zJ^T...,"{\n ""title"": ""title of the story (string)"",...",506,001720349,"{\n""title"": ""The Marriages of the Jests"",\n""ge...","{'title': 'The Marriages of the Jests', 'genre...",1.0
5811,"<schema>\n{\n""title"": ""A WARNING TO THE INHAB...",Q^u'^rtfLspA WARNING T O TH E INHABITANTS OF E...,"{\n ""title"": ""title of the story (string)"",...",505,000000874,"{\n""title"": ""A WARNING TO THE INHABITANTS OF E...",{'title': 'A WARNING TO THE INHABITANTS OF ENG...,1.0
47662,"<schema>\n{\n""title"": ""The Trojan War"",\n""gen...","FVIMVS TROES JEneid. z. THE TRVE TROIANES, Bei...","{\n ""title"": ""title of the story (string)"",...",503,003678244,"{\n""title"": ""The Trojan War"",\n""genre"": ""Epic""...","{'title': 'The Trojan War', 'genre': 'Epic', '...",1.0
49176,"<schema>\n{\n""title"": ""The VViddovvesTearcs"",...",THE VViddovvesTearcs *a Comedie. A$ it was oft...,"{\n ""title"": ""title of the story (string)"",...",502,000660920,"{\n""title"": ""The VViddovvesTearcs"",\n""genre"": ...","{'title': 'The VViddovvesTearcs', 'genre': 'Co...",1.0


In [11]:
attention_scores_df = pd.merge(
    generated_text,
    raw_attention.rename("attention_scores"),
    left_index=True,
    right_index=True,
    how = "inner"
)

In [12]:

attention_scores_df

Unnamed: 0,generated_text,original_text,schema,context_length,book_id,generated_braces,generated_json,score,attention_scores
46893,"<schema>\n{\n""title"": ""The Princefs of Parma""...","f!; ' 'V^- d »* -w . , THE Princefsof^ Parma. ...","{\n ""title"": ""title of the story (string)"",...",3997,003418042,"{\n""title"": ""The Princefs of Parma"",\n""genre"":...",,0.000000,0.001390
3841,"<schema>\n{\n""title"": ""El Quarto de la Fortun...",V E RDADERO ENTRE T E NIMl E N T O del Chrifti...,"{\n ""title"": ""title of the story (string)"",...",3995,002263347,"{\n""title"": ""El Quarto de la Fortuna"",\n""genre...",,0.000000,0.000740
34764,"<schema>\n{\n""title"": ""A View of Religion"",\n...","AN ENQ.UIRY AFTER RELIGION: OR, A View of thtt...","{\n ""title"": ""title of the story (string)"",...",3994,001145182,"{\n""title"": ""A View of Religion"",\n""genre"": ""R...",,0.000000,0.000975
19174,"<schema>\n{\n""title"": ""The Elder Brother"",\n""...",THE ELDER BROTHER A COMEDIE. A#ed at the TSlac...,"{\n ""title"": ""title of the story (string)"",...",3994,001253007,"{\n""title"": ""The Elder Brother"",\n""genre"": ""co...","{'title': 'The Elder Brother', 'genre': 'comed...",0.714286,0.000771
23159,"<schema>\n{\n""title"": ""The Wife-woman of Hogf...","The V Vife- woman » .1 i Of HOQJDOK, A COMEDIE...","{\n ""title"": ""title of the story (string)"",...",3992,001677342,"{\n""title"": ""The Wife-woman of Hogfdon"",\n""gen...","{'title': 'The Wife-woman of Hogfdon', 'genre'...",0.714286,0.000995
...,...,...,...,...,...,...,...,...,...
26838,"<schema>\n{\n""title"": ""The Marriages of the J...",TEX*ftorAMlA. or *'* fHE MARRIAGES OF THE zJ^T...,"{\n ""title"": ""title of the story (string)"",...",506,001720349,"{\n""title"": ""The Marriages of the Jests"",\n""ge...","{'title': 'The Marriages of the Jests', 'genre...",1.000000,0.002249
5811,"<schema>\n{\n""title"": ""A WARNING TO THE INHAB...",Q^u'^rtfLspA WARNING T O TH E INHABITANTS OF E...,"{\n ""title"": ""title of the story (string)"",...",505,000000874,"{\n""title"": ""A WARNING TO THE INHABITANTS OF E...",{'title': 'A WARNING TO THE INHABITANTS OF ENG...,1.000000,0.002363
47662,"<schema>\n{\n""title"": ""The Trojan War"",\n""gen...","FVIMVS TROES JEneid. z. THE TRVE TROIANES, Bei...","{\n ""title"": ""title of the story (string)"",...",503,003678244,"{\n""title"": ""The Trojan War"",\n""genre"": ""Epic""...","{'title': 'The Trojan War', 'genre': 'Epic', '...",1.000000,0.002438
49176,"<schema>\n{\n""title"": ""The VViddovvesTearcs"",...",THE VViddovvesTearcs *a Comedie. A$ it was oft...,"{\n ""title"": ""title of the story (string)"",...",502,000660920,"{\n""title"": ""The VViddovvesTearcs"",\n""genre"": ...","{'title': 'The VViddovvesTearcs', 'genre': 'Co...",1.000000,0.002283


In [25]:
target, probas = attention_scores_df['score'] > 0, attention_scores_df['attention_scores']

In [26]:
roc_auc_score(target.astype(int), probas)

0.6353125447149486

In [27]:
target.corr(probas)

0.23382698844492394