In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, \
  BitsAndBytesConfig, GPTQConfig
import os

while "notebooks" in os.getcwd():
    os.chdir("..")

import scienceplots
from time import time
from tqdm import tqdm
import torch
from langdetect import detect
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from huggingface_hub import notebook_login
from datasets import load_dataset
import math
from typing import List, Optional, Tuple, Union
from torch import nn
from tqdm import tqdm
from IPython.display import clear_output
import warnings
from bert_score import BERTScorer
warnings.filterwarnings("ignore")
from copy import deepcopy
from openai import OpenAI
import json

from src.utils import rotate_half, apply_rotary_pos_emb, repeat_kv, \
    get_context_length, get_generated_text, FileReader, is_text_in_language, rolling_mean,\
    get_context_length

from src.attention_saver import Mistral7BAttentionSaver
from dotenv import load_dotenv

load_dotenv()

# plt.rc('font', family='serif')
# plt.rc('xtick', labelsize='x-small')
# plt.rc('ytick', labelsize='x-small')
plt.style.use(['science','no-latex'])
tqdm.pandas()

In [3]:
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",
    cache_dir = "/Data"    
)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

df = load_dataset("stas/openwebtext-10k", cache_dir= "/Data")['train'].to_pandas()
df["text_len"] = df["text"].apply(lambda x: len(x.split(" ")))

In [5]:
base_instruction = "Summarize in french"
df["context_length"] = (base_instruction + " \n" + df["text"])\
    .progress_apply(get_context_length, tokenizer = tokenizer)

100%|██████████| 10000/10000 [00:19<00:00, 504.12it/s]


In [6]:
base_path = "data/complete_study_200_tokens/checkpoints/"
all_results = []
for delta_attention in [0., 0.5, 1., 2.0, 5.0]:
    for all_layers in ["all", "none"]:
        path = os.path.join(
            base_path,
            f"{all_layers}_layers_generated_delta={delta_attention}.pkl"
        )
        try:
            results_df = pd.read_pickle(path).T

        except Exception as e:
            print(e)
            continue

        parsed_results_dict = dict()

        for epoch in range(len(results_df.columns)-1):
            for (idx, result_epoch) in results_df.loc[:,f"epoch {epoch}"].items():
                s = pd.Series(result_epoch)\
                    .apply(get_generated_text)\

                data = pd.DataFrame(s).T
                data.index = [idx]

                if not epoch in parsed_results_dict:
                    parsed_results_dict[epoch] = []

                parsed_results_dict[epoch].append(data)

            parsed_results_dict[epoch] = pd.concat(parsed_results_dict[epoch])

        all_dfs = []

        for epoch in parsed_results_dict.keys():
            temp_df = pd.melt(
                parsed_results_dict[epoch].reset_index(),
                var_name = "instruction",
                value_name = "generated_text",
                id_vars = "index",
            )

            temp_df["is_french"] = temp_df["generated_text"].apply(is_text_in_language)

            temp_df["generation_epoch"] = epoch

            all_dfs.append(temp_df)

        melted_df = pd.concat(all_dfs)

        melted_df = pd.merge(
            melted_df,
            df[["context_length", "text"]],
            left_on="index",
            right_index=True
        )

        melted_df["context_length_bins"] = pd.cut(
            melted_df["context_length"], 
            np.arange(0,6_500,500)
        )

        melted_df.dropna(inplace=True)

        study_name = f"$\Delta$={delta_attention}"
        
        if all_layers == 'first':
            study_name= f"$\Delta$={delta_attention}, first layer only"

        elif  all_layers == 'all':

            study_name = f"$\Delta$={delta_attention}"
            
        melted_df["study"] = study_name
        
        if delta_attention ==0:
            melted_df["study"] = f"Raw model"

        all_results.append(melted_df)

[Errno 2] No such file or directory: 'data/complete_study_200_tokens/checkpoints/all_layers_generated_delta=0.0.pkl'
Exception raised while analysing the text  21 =85 9/5=4 1 2 2 0 0 0 0 2 2 1 2 2 0 0 4 9 9 9 0 0 0 0 4 9 9 9 0 0 0 0 4 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 3 9 9 9 0 0 0 0 4 9 9 9 0 0 0 0
Exception raised while analysing the text  1
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 7
Exception raised while analysing the text  1
2
3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 7
Exception raised while analysing the text  ᖺᖷ የስትስት ስትስት የንሕትንም የርንንም ስትስትስት የስትስት ስትስት የስትስት ስትስት የንንም የንሕንስትንም የስትስት የንስትንም የንንም የስትስት የንስትንም የንስትንም የን

In [7]:
pd.concat(all_results).study.unique()

array(['Raw model', '$\\Delta$=0.5', '$\\Delta$=1.0', '$\\Delta$=2.0',
       '$\\Delta$=5.0'], dtype=object)

In [17]:
llm_summaries_df = pd.read_csv("data/summarization_train.csv", index_col= 0)

In [18]:
llm_summaries_df

Unnamed: 0,target,text,text_idx,instruction,prompt
0,Voici un résumé en français du texte sur la pa...,The partition of Quebec refers to the secessio...,134,Summarize in french:,Summarize in french: \nThe partition of Quebec...
1,Voici un résumé en français du texte :\n\n« To...,Everything you know about ARGs is WRONG 22 Dec...,9341,Summarize in french:,Summarize in french: \nEverything you know abo...
2,Voici un résumé en français du texte original ...,Cyrstal Meth Addiction\n\nCrystal meth addicti...,2973,Summarize in french:,Summarize in french: \nCyrstal Meth Addiction\...
3,Voici un résumé en français du texte :\n\nLe m...,Quảng Đức is descriptive of meritorious attrib...,8280,Summarize in french:,Summarize in french: \nQuảng Đức is descriptiv...
4,Voici un résumé de l'article en français :\n\n...,"World Electioneering Entertainment 2016: 1,000...",3921,Summarize in french:,Summarize in french: \nWorld Electioneering En...
...,...,...,...,...,...
235,Voici un résumé du texte en français :\n\nUne ...,A New Zealand firm says it successfully triall...,7394,Summarize in french:,Summarize in french: \nA New Zealand firm says...
236,Voici un résumé du texte en français :\n\nFaiz...,Originally Posted by Faizan Lakhani Originally...,6932,Summarize in french:,Summarize in french: \nOriginally Posted by Fa...
237,Voici un résumé du texte en français :\n\nÀ Da...,"DAVAO CITY, Philippines — Residents who are 18...",3943,Summarize in french:,"Summarize in french: \nDAVAO CITY, Philippines..."
238,LORSQUE LA RHÉTORIQUE DE LA DROITE TOURNE À LA...,WHEN THE RIGHT’S RHETORIC TURNS VIOLENT…. In t...,4922,Summarize in french:,Summarize in french: \nWHEN THE RIGHT’S RHETOR...


In [8]:
results_df = pd.concat(all_results)

In [9]:
results_df['text'].nunique()

240

In [10]:
results_df.query("is_french == 1 & study != 'Raw model'")

Unnamed: 0,index,instruction,generated_text,is_french,generation_epoch,context_length,text,context_length_bins,study
0,134,Summarize in french:,"2015: 2eme referendum\n\nLe 26 novembre 2015,...",True,0,5995,The partition of Quebec refers to the secessio...,"(5500, 6000]",$\Delta$=0.5
18,7890,Summarize in french:,"""Les fousées de la vie"" en français signifie ...",True,0,5598,Hailed as the undisputed queens of ’60s-inspir...,"(5500, 6000]",$\Delta$=0.5
22,6952,Summarize in french:,Venezuela est une république située à l'ouest...,True,0,5446,Americans might be fooled by mass media misinf...,"(5000, 5500]",$\Delta$=0.5
23,2719,Summarize in french:,Je suis prêt à faire tout son mieux pour vous...,True,0,5428,Jahlil Okafor Is On His Way Up High-school bas...,"(5000, 5500]",$\Delta$=0.5
28,760,Summarize in french:,"En tant que français, je pense que la langue ...",True,0,5301,It is easily the most depraved little episode ...,"(5000, 5500]",$\Delta$=0.5
...,...,...,...,...,...,...,...,...,...
714,3023,You must summarize the following text in french:,Summarize this text into French:\nLes trois c...,True,9,302,statigr.am/kimbo_ks13 A study by three scienti...,"(0, 500]",$\Delta$=5.0
715,7394,You must summarize the following text in french:,Bonner de broy et de la cellulose dans les ch...,True,9,256,A New Zealand firm says it successfully triall...,"(0, 500]",$\Delta$=5.0
717,3943,You must summarize the following text in french:,You must summarize the text in French: Les ré...,True,9,235,"DAVAO CITY, Philippines — Residents who are 18...","(0, 500]",$\Delta$=5.0
718,4922,You must summarize the following text in french:,Voilently.... \n\nDans le sillage des fusilla...,True,9,224,WHEN THE RIGHT’S RHETORIC TURNS VIOLENT…. In t...,"(0, 500]",$\Delta$=5.0


In [19]:
eval_df = pd.merge(
    results_df,
    llm_summaries_df[['text_idx', 'target']],
    left_on='index',
    right_on='text_idx'
)

In [11]:
results_df.query("is_french == 1 & study == 'Raw model'").groupby("index")\
    .count()\
    .mean(axis=0)

instruction            8.314286
generated_text         8.314286
is_french              8.314286
generation_epoch       8.314286
context_length         8.314286
text                   8.314286
context_length_bins    8.314286
study                  8.314286
dtype: float64

In [12]:
results_df['context_length'].mean()

2995.8806444143042

In [21]:
TEMPLATE = """
    You are a helpful assistant that ranks summaries. 
    I will provide you a target summary and two summaries (0 and 1) of this text in French .

    You must output which one is a better summary, based on both the quality of the summary and the quality of the French text.

    Here is the target summary: 
    <text>
    {target_summary}
    </text>

    Here is summary 0:
    <0>
    {summary_0}
    </0>

    Here is summary 1:
    <1>
    {summary_1}
    </1>

    Answer in the following format (JSON):
    {{
        "best_summary" : (0 or 1)
        "explaination" : a short explaination why.
    }}
"""

In [63]:

openai = OpenAI(
    api_key=os.environ["DEEP_INFRA_API_KEY"],
    base_url="https://api.deepinfra.com/v1/openai",
)
model_name = "meta-llama/Meta-Llama-3-70B-Instruct"

def winning_rate_llm(
    target_summary : str,
    text_delta : str,
    text_raw : str
):
    
    index_of_delta = int(np.random.random() > 1.2)
    
    shuffling_dict = {
        index_of_delta : text_delta,
        1-index_of_delta: text_raw
    }

    inverse_shuffling_map = {
        index_of_delta : "modified",
        1 - index_of_delta: "raw"
    }

    prompt = TEMPLATE.format(
        target_summary = target_summary,
        summary_0 = shuffling_dict[0],
        summary_1 = shuffling_dict[1]
    )

    

    chat_completion = openai.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=False,
    )
    
    generated_text = chat_completion\
        .choices[0]\
        .message\
        .content
    prompt_tokens = chat_completion\
        .usage\
        .prompt_tokens
    output_tokens = chat_completion\
        .usage\
        .completion_tokens
    
    try:
        generated_json = json.loads(generated_text)
        generated_json['best_summary'] = inverse_shuffling_map[generated_json['best_summary']]

    except:
        return None
    

    return generated_json

In [64]:
np.random.seed(42)
texts = np.random.choice(eval_df.text_idx.unique(), 30)

mask = eval_df.text_idx.isin(texts)
experiment_df = eval_df[mask]

In [65]:
raw_model_performance = experiment_df.query(f"study == 'Raw model' & is_french == 1 & instruction == 'Summarize in french: '")

In [66]:
# fig, axs = plt.subplots(2, 2, figsize=(16, 12))
study_winning_rates = dict()

for i, study_name in enumerate(experiment_df.study.unique()):
    if study_name == 'Raw model':
        continue

    
    clear_output()
    print(f"comparing french text generated by {study_name} to raw model")
    winning_dict = dict()
    
    augmented_model_performance = experiment_df.query(f"study == '{study_name}' & is_french == 1 & instruction == 'Summarize in french: '")

    for idx, row in tqdm(augmented_model_performance.iterrows(), total = len(augmented_model_performance)):
        text_idx = row['index']
        base_text = row['text']
        target_summary = row['target']
        generated_text_delta = row['generated_text']
        
        raw_model_text_df = raw_model_performance.query(f"index == {text_idx}")

        records = []
        for _, raw_model_row in raw_model_text_df.iterrows():
            generated_text_raw_model = raw_model_row['generated_text']

            generated_json = winning_rate_llm(target_summary, generated_text_delta, generated_text_raw_model)
            records.append(generated_json)

        winning_dict[text_idx] = records
    
    winning_rate_df = pd.DataFrame(pd.DataFrame(dict([ (k,pd.Series(v)) for k,v in winning_dict.items() ])))

    study_winning_rates[study_name] = winning_rate_df

#     j = (i-1) // 2
#     k = (i-1) % 2
#     sns.heatmap(winning_rate_df, cmap='viridis' ,ax= axs[j][k])
#     axs[j][k].set_title(f"Winning matrix of sutdy : {study_name} - winning_rate = {round(winning_rate_df.mean().mean(),2)}")


# fig.subplots_adjust(hspace=0.3)

comparing french text generated by $\Delta$=0.5 to raw model


  0%|          | 0/107 [00:00<?, ?it/s]

In [62]:
winning_rate_llm(target_summary, generated_text_delta, generated_text_raw_model)

{
"best_summary" : 0,
"explaination" : "Summary 0 is totally unrelated to the target summary, but at least it's a coherent text in French. Summary 1 is also unrelated and lacks coherence, with sentences that don't form a logical narrative. Neither summary is a good representation of the target text, but Summary 0 is a better written French text."
}


{'best_summary': 'modified',
 'explaination': "Summary 0 is totally unrelated to the target summary, but at least it's a coherent text in French. Summary 1 is also unrelated and lacks coherence, with sentences that don't form a logical narrative. Neither summary is a good representation of the target text, but Summary 0 is a better written French text."}

In [48]:
augmented_model_performance

Unnamed: 0,index,instruction,generated_text,is_french,generation_epoch,context_length,text,context_length_bins,study,text_idx,target
5110,7686,Summarize in french:,"Le commentaire original était :\n\n""Ce dernie...",True,0,4138,Hardware\n\nI've had the Moto X for five days ...,"(4000, 4500]",$\Delta$=0.5,7686,Voici un résumé du texte en français :\n\nJ'ai...
5138,4762,Summarize in french:,Quality of life (QOL) est un concept général ...,True,0,3527,Quality of life (QOL) is an overarching term f...,"(3500, 4000]",$\Delta$=0.5,4762,La qualité de vie (QOL) est un terme englobant...
5145,1294,Summarize in french:,Summarized French:\n\nNous sommes en même pos...,True,0,3355,"With tonight's win, Florida is now 44-3 and st...","(3000, 3500]",$\Delta$=0.5,1294,Voici un résumé du texte en français :\n\nAvec...
5169,247,Summarize in french:,Le lac Grand Slave (en français : Grand lac d...,True,0,2654,Large lake in the Northwest Territories of Can...,"(2500, 3000]",$\Delta$=0.5,247,Voici un résumé du texte en français :\n\nLe g...
5196,8927,Summarize in french:,Sibyl Moon est un artiste indépendant qui est...,True,0,2067,NetHack\n\nThe first roguelike I played was Ne...,"(2000, 2500]",$\Delta$=0.5,8927,Voici un résumé du texte en français :\n\nNetH...
...,...,...,...,...,...,...,...,...,...,...,...
12183,4487,You must summarize the following text in french:,Here's a possible translation of the text int...,True,9,1319,Artist's illustration of an asteroid that has ...,"(1000, 1500]",$\Delta$=0.5,4487,Voici un résumé du texte en français :\n\n Dan...
12186,6112,You must summarize the following text in french:,L'Occupy Wall Street a besoin d'artistes !\n\...,True,9,1254,Call for Artists: Wall Street Occupennial!\n\n...,"(1000, 1500]",$\Delta$=0.5,6112,Voici un résumé du texte en français :\n\nAppe...
12197,3130,You must summarize the following text in french:,Le ancien agent de renseignement central (CIA...,True,9,859,A former Central Intelligence Agency (CIA) ope...,"(500, 1000]",$\Delta$=0.5,3130,Un ancien opérateur de la Central Intelligence...
12205,8285,You must summarize the following text in french:,Les Browns commencent leur première partie de...,True,9,731,Tom Dahlin/Getty Images\n\nThe Cleveland Brown...,"(500, 1000]",$\Delta$=0.5,8285,Voici un résumé du texte en français :\n\nLes ...


In [44]:
pd.Series(winning_dict)

5110    [{'best_summary': 'raw', 'explaination': 'Summ...
5138    [{'best_summary': 'raw', 'explaination': 'Summ...
5145    [{'best_summary': 'modified', 'explaination': ...
5169    [{'best_summary': 'modified', 'explaination': ...
5196    [{'best_summary': 'raw', 'explaination': 'Summ...
5226    [{'best_summary': 'raw', 'explaination': 'Summ...
5230    [{'best_summary': 'modified', 'explaination': ...
dtype: object

In [47]:
pd.DataFrame(pd.DataFrame(dict([ (k,pd.Series(v)) for k,v in winning_dict.items() ]))).T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
5110,"{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...",,,,
5138,"{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explaination': 'Summa...",,,,,
5145,"{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explanation': 'Summar...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explanation': 'Summar...",,,,
5169,"{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...",,,,
5196,"{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...",,"{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...",,
5226,"{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explanation': 'Summar...","{'best_summary': 'raw', 'explanation': 'Summar...",,"{'best_summary': 'raw', 'explanation': 'Summar...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explanation': 'Summar...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'raw', 'explaination': 'Summa..."
5230,"{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'modified', 'explaination': '...",,"{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'modified', 'explanation': 'S...","{'best_summary': 'raw', 'explaination': 'Summa...","{'best_summary': 'modified', 'explaination': '...","{'best_summary': 'modified', 'explaination': '...",


In [19]:
raw_model_performance.query(f"index == {text_idx}")

Unnamed: 0,index,instruction,generated_text,is_french,generation_epoch,context_length,text,context_length_bins,study
480,134,You must summarize the following text in french:,L'article partition:\n\nL'article Partition e...,True,1,5995,The partition of Quebec refers to the secessio...,"(5500, 6000]",Raw model
240,134,Important: Summarize in french:,Frédéric Passy. Le Partition. Édition parisie...,True,2,5995,The partition of Quebec refers to the secessio...,"(5500, 6000]",Raw model
0,134,Summarize in french:,"Notes:\n[1] Jean-Louis Chartrand, ""L'influenc...",True,4,5995,The partition of Quebec refers to the secessio...,"(5500, 6000]",Raw model
480,134,You must summarize the following text in french:,Partition de la province de Québec à l'état d...,True,4,5995,The partition of Quebec refers to the secessio...,"(5500, 6000]",Raw model
240,134,Important: Summarize in french:,Here is an answer to your question in French:...,True,6,5995,The partition of Quebec refers to the secessio...,"(5500, 6000]",Raw model
480,134,You must summarize the following text in french:,Pour faire comprendre et traduire ce texte ca...,True,6,5995,The partition of Quebec refers to the secessio...,"(5500, 6000]",Raw model
