In [10]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import accuracy_score

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ProbeNN(nn.Module):
    """Feed-forward probe mapping an embedding to a probability.

    Architecture: ``input_dim -> 256 -> 128 -> 64 -> 1`` with ReLU after each hidden
    layer and a sigmoid on the single output unit. The attribute names ``layer1``,
    ``layer2``, ``layer3`` and ``output`` are the keys of the saved ``state_dict``, so
    they must stay as they are for previously trained probes to load.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 256)
        self.layer2 = nn.Linear(256, 128)
        self.layer3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)

    def forward(self, x):
        hidden = x
        for linear in (self.layer1, self.layer2, self.layer3):
            hidden = torch.relu(linear(hidden))
        return torch.sigmoid(self.output(hidden))



def evaluate_unbalanced_testset(test_pred_probs, df_test, accuracy_threshold=0.5):
    """Score mini-fact probe predictions at the sentence level.

    Mini-facts are grouped by ``gen_sentence``. A sentence's true label is 1 only when
    every one of its mini-facts is labelled 1. A sentence counts as predicted correctly
    when "every mini-fact predicted true" agrees with "every mini-fact labelled true":
    an all-true sentence needs all facts predicted true, any other sentence needs at
    least one fact predicted false.

    Args:
        test_pred_probs: Per-row probabilities aligned positionally with ``df_test``.
            An ``(N, 1)`` array, as produced by the probe, is flattened.
        df_test: Frame with ``gen_sentence`` and ``label_mini_fact`` columns. It is not
            modified.
        accuracy_threshold: A mini-fact is predicted true when its probability is
            strictly greater than this value.

    Returns:
        ``(accuracy, n_all_true_sentences, n_other_sentences, auroc)``. The AUROC scores
        each sentence by the minimum probability among its mini-facts.

    Raises:
        ValueError: if ``df_test`` has no rows.
    """
    if len(df_test) == 0:
        raise ValueError("evaluate_unbalanced_testset needs at least one mini-fact row")

    df_pred = df_test.copy()
    df_pred['pred_prob'] = np.asarray(test_pred_probs, dtype=float).reshape(-1)
    df_pred['binary_pred'] = (df_pred['pred_prob'] > accuracy_threshold).astype("int64")

    # One pass over the groups; counts are derived from this table instead of being
    # accumulated as side effects inside groupby.apply.
    per_sentence = df_pred.groupby('gen_sentence').agg(
        n_facts=('label_mini_fact', 'size'),
        n_true_labels=('label_mini_fact', 'sum'),
        n_true_preds=('binary_pred', 'sum'),
        pred_group_prob=('pred_prob', 'min'),
    )
    all_labels_true = per_sentence['n_true_labels'] == per_sentence['n_facts']
    all_preds_true = per_sentence['n_true_preds'] == per_sentence['n_facts']
    correct = all_labels_true == all_preds_true

    accuracy = float(correct.mean())
    collect1s = int(all_labels_true.sum())
    collect0s = int((~all_labels_true).sum())
    auc_roc_score = roc_auc_score(all_labels_true.astype(int), per_sentence['pred_group_prob'])
    return accuracy, collect1s, collect0s, auc_roc_score



def compute_roc_curve(test_labels, test_pred_prob):
    """Return ``(area_under_roc, false_positive_rates, true_positive_rates)``.

    The rates come from ``roc_curve``; the area is integrated over them with ``auc``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(test_labels, test_pred_prob)
    area = auc(false_pos_rate, true_pos_rate)
    return area, false_pos_rate, true_pos_rate


def get_results(df_test, model_path, layer, probe_method, accuracy_threshold):
    """Run a saved probe on a test split and score it.

    Args:
        df_test: Frame with an ``embeddings{layer}_{probe_method}`` column (one vector
            per row) and a ``label_{probe_method}`` column.
        model_path: Path of the ``ProbeNN`` state dict to load.
        layer: Layer index used in the embedding column name.
        probe_method: Column-name suffix; this notebook uses ``"sentence"`` and ``"mini_fact"``.
        accuracy_threshold: Probability above which a row is predicted positive.

    Returns:
        ``(pred_probs, accuracy, roc_auc)``. ``pred_probs`` is the raw probe output
        (one column, since the probe has a single output unit).

    Raises:
        ValueError: if ``df_test`` has no rows.
    """
    if len(df_test) == 0:
        raise ValueError("get_results needs a non-empty test frame")

    embeddings = np.array(df_test[f'embeddings{layer}_{probe_method}'].tolist())
    labels = df_test[f'label_{probe_method}']

    model = ProbeNN(embeddings.shape[1]).to(device)
    # map_location lets a probe saved from a GPU session load when ``device`` is the CPU.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    with torch.no_grad():
        inputs = torch.tensor(embeddings, dtype=torch.float32).to(device)
        test_pred_prob = model(inputs).cpu().numpy()

    test_accuracy = accuracy_score(labels, test_pred_prob > accuracy_threshold)
    roc_auc, _fpr, _tpr = compute_roc_curve(labels, test_pred_prob)
    return test_pred_prob, test_accuracy, roc_auc


In [19]:
results_fever_test = []
results_hover_test = []

# Probe (train) and evaluation (test) datasets, paired element-wise by the loop below.
train_datasets = ["hover"]
test_datasets = ["hover"]

model_name = "llama"  # "llama" or "phi": selects which processed dataset files are read
layers = [-1]         # hidden-layer indices to evaluate
balanced = False      # True -> accuracy table on balanced splits; False -> AUROC table

# Caption/label for the LaTeX table printed further down. They are derived from the run
# configuration so the table cell never depends on a value left over from another session.
if model_name == "phi":
    caption = "\\textbf{AUROC} scores for probes trained on Llama3-8B embeddings and tested on Phi generations"
    label = "tab:auroc_probes_phi"
elif train_datasets != test_datasets:
    caption = "\\textbf{AUROC} scores across different layers for cross-testing: probes trained on HoVer are tested on FEVER, and probes trained on FEVER are tested on HoVer."
    label = "tab:auroc_probes_cross_test"
else:
    caption = "\\textbf{AUROC} across different layers for FEVER and HoVer test datasets"
    label = "tab:auroc_probes"



def balance_dataframe(df, label_name, random_state=42):
    """Downsample ``df`` so rows with label 1 and label 0 are equally frequent.

    Both classes are sampled without replacement down to the size of the smaller one.
    Rows whose label is neither 0 nor 1 are dropped. The result lists the sampled
    positives first, then the negatives, with a fresh ``RangeIndex``.

    Args:
        df: Frame to balance; it is not modified.
        label_name: Column holding the binary label (converted with ``str``).
        random_state: Seed passed to ``DataFrame.sample`` so the subset is reproducible.
    """
    labels = df[str(label_name)]
    positives = df[labels == 1]
    negatives = df[labels == 0]
    n_per_class = min(len(positives), len(negatives))
    sampled = [part.sample(n_per_class, random_state=random_state) for part in (positives, negatives)]
    return pd.concat(sampled).reset_index(drop=True)

# LaTeX column titles for the hidden-layer indices used in this notebook (None otherwise).
LAYER_DISPLAY_NAMES = {
    1: "First \\newline Hidden \\newline Layer",
    -24: "Hidden \\newline Layer: \\newline 9",
    -16: "Hidden \\newline Layer: \\newline 17",
    -8: "Hidden \\newline Layer: \\newline 25",
    -1: "Last \\newline Hidden \\newline Layer",
}
# Result lists keyed by test dataset; other test datasets are evaluated but not recorded.
results_by_test_dataset = {"fever": results_fever_test, "hover": results_hover_test}

for test_dataset, train_dataset in zip(test_datasets, train_datasets):
    for layer in layers:
        print(f"Layer: {layer}")
        layer_name = LAYER_DISPLAY_NAMES.get(layer)

        # --- Load the test splits ---------------------------------------------------
        if model_name == "llama":
            test_dataset_name = f"processed_datasets_{model_name}_{test_dataset}_layer{layer}"
            split = "balanced" if balanced else "unbalanced"
            df_test_mini_fact = pd.read_pickle(f"./{test_dataset_name}/mini_fact_{test_dataset}_test_{split}.pkl")
            df_test_sentence = pd.read_pickle(f"./{test_dataset_name}/sentence_{test_dataset}_test_{split}.pkl")
            rebalance = not balanced
        elif model_name == "phi":
            # The Phi branch always reads the unbalanced splits and rebalances them,
            # regardless of `balanced`.
            test_dataset_name = f"processed_datasets_{model_name}"
            df_test_mini_fact = pd.read_pickle(f"./{test_dataset_name}/mini_fact_{test_dataset}_layer{layer}_test_unbalanced.pkl")
            df_test_sentence = pd.read_pickle(f"./{test_dataset_name}/sentence_{test_dataset}_layer{layer}_test_unbalanced.pkl")
            rebalance = True
        else:
            raise ValueError(f"Unsupported model_name: {model_name!r}")

        if rebalance:
            # Balance the sentence labels, then keep only the mini-facts of kept sentences.
            df_test_sentence = balance_dataframe(df_test_sentence, 'label_sentence')
            kept = df_test_mini_fact['gen_sentence'].isin(df_test_sentence['output_sentence'])
            # .copy(): prediction columns are added below, so this must not be a view.
            df_test_mini_fact = df_test_mini_fact[kept].copy()

        model_path_sentence = f"./probes/sentence_embeddings{layer}_{train_dataset}.pth"
        model_path_mini_fact = f"./probes/mini_fact_embeddings{layer}_{train_dataset}.pth"

        # --- Score both probes --------------------------------------------------------
        test_pred_probs_sentences, test_accuracy_sentences, roc_auc_sentences = get_results(
            df_test_sentence, model_path_sentence,
            layer=layer, probe_method="sentence", accuracy_threshold=0.5)
        test_pred_probs_mini_fact, test_accuracy_mini_fact, roc_auc_mini_fact = get_results(
            df_test_mini_fact, model_path_mini_fact,
            layer=layer, probe_method="mini_fact", accuracy_threshold=0.5)

        df_test_mini_fact['pred_prob'] = test_pred_probs_mini_fact
        df_test_mini_fact['binary_pred'] = (df_test_mini_fact['pred_prob'] > 0.5).astype("int64")
        # NOTE: one fixed file name; each (dataset, layer) iteration overwrites it.
        df_test_mini_fact.to_pickle("prediction.pkl")

        if balanced:
            auc_roc_score_mini_facts = None
        else:
            correct_ratio, collect1s, collect0s, auc_roc_score_mini_facts = evaluate_unbalanced_testset(
                test_pred_probs_mini_fact, df_test_mini_fact, accuracy_threshold=0.5)

        results = results_by_test_dataset.get(test_dataset)
        if results is not None:
            print(f"Test dataset: {test_dataset}")
            results.append({
                "train_dataset": train_dataset,
                "test_dataset": test_dataset,
                "layer": layer_name,
                "accuracySentence": test_accuracy_sentences if balanced else None,
                "auc_sentences": roc_auc_sentences,
                "accuracyMiniFacts": test_accuracy_mini_fact if balanced else None,
                "auc_mini_facts": roc_auc_mini_fact,
                "auc_mini_facts_sentences_match": auc_roc_score_mini_facts,
            })

Layer: -1
Test dataset: hover


In [20]:
# Display the HoVer result rows collected by the evaluation loop.
results_hover_test

[{'train_dataset': 'hover',
  'test_dataset': 'hover',
  'layer': 'Last \\newline Hidden \\newline Layer',
  'accuracySentence': None,
  'auc_sentences': 0.7846093400217464,
  'accuracyMiniFacts': None,
  'auc_mini_facts': 0.8162134434837778,
  'auc_mini_facts_sentences_match': 0.8105132352000586}]

In [103]:
import numpy as np
import pandas as pd

# Column order of the LaTeX tables. These strings must match the `layer` values stored in
# the result rows so prepare_data can reindex the pivot on them.
layers_order = [
    "First \\newline Hidden \\newline Layer",
    *[f"Hidden \\newline Layer: \\newline {layer_number}" for layer_number in (9, 17, 25)],
    "Last \\newline Hidden \\newline Layer",
]


# Function to prepare data
def prepare_data(df, value_vars, index_name):
    """Reshape result rows into one row per metric and one column per layer.

    Args:
        df: Results frame with ``train_dataset``, ``test_dataset``, ``layer`` and the
            ``value_vars`` columns.
        value_vars: Metric columns to keep.
        index_name: Name of the output column that holds the metric key.

    Returns:
        A flat frame with columns ``[index_name, "train_dataset", "test_dataset"]``
        followed by the layers of the global ``layers_order``; layers without a value
        are NaN. Duplicate entries are averaged (``pivot_table`` default).
    """
    id_columns = ["train_dataset", "test_dataset", "layer"]
    long_form = df.melt(id_vars=id_columns, value_vars=value_vars, var_name=index_name, value_name="Value")
    wide_form = long_form.pivot_table(
        index=[index_name, "train_dataset", "test_dataset"],
        columns="layer",
        values="Value",
    ).reindex(columns=layers_order, fill_value=np.nan)
    wide_form.columns.name = None  # drop the leftover "layer" axis label
    return wide_form.reset_index()


def generate_table_lines(train_dataset_name, test_dataset_name, df_pivot, metrics, metric_names):
    """Build the LaTeX rows for one dataset section of the results table.

    The first row carries the train/test label; later rows leave that cell empty. In each
    row the largest non-NaN value across the layers is bolded, and missing layers are
    shown as ``nan``. An ``\\addlinespace`` line follows the last metric row.

    Args:
        train_dataset_name, test_dataset_name: Display names for the section label.
        df_pivot: Output of ``prepare_data``; its first column holds the metric keys and
            the global ``layers_order`` entries are its value columns.
        metrics: Metric keys to emit, in order.
        metric_names: Display names, parallel to ``metrics``.
    """
    metric_column = df_pivot.columns[0]
    lines = []

    for idx, metric in enumerate(metrics):
        metric_name = metric_names[idx]
        train_test_label = (
            f"\\scriptsize{{\\textbf{{{train_dataset_name}}}}} \\newline \\scriptsize{{\\textbf{{{test_dataset_name}}}}}"
            if idx == 0
            else ""
        )

        metric_row = df_pivot.loc[df_pivot[metric_column] == metric, layers_order].iloc[0]
        # Series.max skips NaN. The builtin max() returned NaN whenever the first layer was
        # missing, so nothing got bolded in single-layer runs.
        max_value_row = metric_row.max()

        row_values = []
        for layer in layers_order:
            value = float(metric_row[layer])
            if np.isnan(value):
                row_values.append("nan")
                continue
            formatted_value = f"{value:.3f}"
            if np.isclose(value, max_value_row):  # highlight the best layer
                formatted_value = f"\\textbf{{{formatted_value}}}"
            row_values.append(formatted_value)

        row = [train_test_label, f"\\small{{{metric_name}}}"] + row_values
        lines.append(" & ".join(row) + " \\\\")

    if metrics:
        lines.append("\\addlinespace[5pt]")  # space after the last metric row
    return lines


# Function to create a complete LaTeX table
def _section_dataset_names(df_pivot):
    """Return the ``(train, test)`` labels shown in the first column of a table section.

    Dataset ids come from the first row of ``df_pivot``: "fever" becomes "FEVER" and
    anything else becomes "HoVer". For Phi runs (global ``model_name``) the generating
    models are appended to the names.
    """
    def display_name(dataset_id):
        return "FEVER" if dataset_id == "fever" else "HoVer"

    train_dataset_name = display_name(df_pivot['train_dataset'].iloc[0])
    test_dataset_name = display_name(df_pivot['test_dataset'].iloc[0])
    if model_name == "phi":
        train_dataset_name += " (Llama3-8B) /"
        test_dataset_name += " (Phi)"
    else:
        train_dataset_name += " /"
    return train_dataset_name, test_dataset_name


def create_latex_table(df_pivot_fever, df_pivot_hover, metrics, metric_names, caption, label):
    """Render the FEVER and HoVer pivots as a single LaTeX table environment.

    Args:
        df_pivot_fever, df_pivot_hover: Pivots from ``prepare_data``. A section is skipped
            when its pivot is ``None`` or empty, so single-dataset runs still render.
        metrics: Metric keys, in display order.
        metric_names: Display names, parallel to ``metrics``.
        caption, label: Inserted verbatim into the caption and label commands.

    Returns:
        The LaTeX source as one newline-joined string.

    Raises:
        ValueError: if both pivots are missing or empty.
    """
    sections = [p for p in (df_pivot_fever, df_pivot_hover) if p is not None and not p.empty]
    if not sections:
        raise ValueError("create_latex_table needs at least one non-empty pivot table")

    header = ["Train/Test Set", "Probe"] + layers_order
    table_lines = [" & ".join(header) + " \\\\ \\hline"]

    for position, df_pivot in enumerate(sections):
        if position > 0:
            table_lines.append("\\addlinespace[5pt]")  # gap between dataset sections
        train_dataset_name, test_dataset_name = _section_dataset_names(df_pivot)
        table_lines.extend(generate_table_lines(train_dataset_name, test_dataset_name, df_pivot, metrics, metric_names))
        table_lines.append("\\hline")

    # Assemble the LaTeX table
    latex_table = [
        "\\begin{table}[h]",
        "\\centering",
        "\\small",
        "\\renewcommand{\\arraystretch}{1.2}",
        "\\setlength{\\tabcolsep}{3pt}",
        "\\begin{tabular}{p{3cm} p{3.5cm} @{{\\hskip 0.5cm}} p{1.5cm} p{1.5cm} p{1.5cm} p{1.5cm} p{1.5cm} }",
        "\\hline",
    ]
    latex_table.extend(table_lines)
    latex_table.append("\\end{tabular}")
    latex_table.append(f"\\caption{{{caption}}}")
    latex_table.append(f"\\label{{{label}}}")
    latex_table.append("\\end{table}")

    return "\n".join(latex_table)


fever_df = pd.DataFrame(results_fever_test)
hover_df = pd.DataFrame(results_hover_test)


def _pivot_or_none(results_df, value_vars, index_name):
    """Return ``prepare_data`` of a non-empty results frame, or ``None`` when it is empty.

    A frame built from an empty results list has no id columns, so ``melt`` would fail.
    """
    if results_df.empty:
        return None
    return prepare_data(results_df, value_vars, index_name)


if balanced:
    accuracy_columns = ['accuracySentence', 'accuracyMiniFacts']
    fever_accuracy_pivot = _pivot_or_none(fever_df, accuracy_columns, 'Accuracy')
    hover_accuracy_pivot = _pivot_or_none(hover_df, accuracy_columns, 'Accuracy')

    # Create and print the Accuracy table
    metrics_accuracy = ['accuracyMiniFacts', 'accuracySentence']
    metric_names_accuracy = ['Mini Facts', 'Sentence']
    latex_table_accuracy = create_latex_table(
        fever_accuracy_pivot, hover_accuracy_pivot,
        metrics_accuracy, metric_names_accuracy,
        '\\textbf{Accuracy} Across Different Layers for balanced FEVER and HoVer test datasets',
        'tab:accuracy_probes'
    )
    print(latex_table_accuracy)
else:
    auroc_columns = ["auc_sentences", "auc_mini_facts", "auc_mini_facts_sentences_match"]
    fever_auroc_pivot = _pivot_or_none(fever_df, auroc_columns, "AUROC")
    hover_auroc_pivot = _pivot_or_none(hover_df, auroc_columns, "AUROC")

    # Generate the AUROC table
    metrics_auc = ["auc_mini_facts", "auc_mini_facts_sentences_match", "auc_sentences"]
    metric_names_auc = ["Mini Facts -> Mini Facts Level", "Mini Facts -> Sentence Level", "Sentences"]
    latex_table_auc = create_latex_table(
        fever_auroc_pivot,
        hover_auroc_pivot,
        metrics_auc,
        metric_names_auc,
        caption,
        label
    )
    print(latex_table_auc)


\begin{table}[h]
\centering
\small
\renewcommand{\arraystretch}{1.2}
\setlength{\tabcolsep}{3pt}
\begin{tabular}{p{3cm} p{3.5cm} @{{\hskip 0.5cm}} p{1.5cm} p{1.5cm} p{1.5cm} p{1.5cm} p{1.5cm} }
\hline
Train/Test Set & Probe & First \newline Hidden \newline Layer & Hidden \newline Layer: \newline 9 & Hidden \newline Layer: \newline 17 & Hidden \newline Layer: \newline 25 & Last \newline Hidden \newline Layer \\ \hline
\scriptsize{\textbf{FEVER (Llama3-8B) /}} \newline \scriptsize{\textbf{FEVER (Phi)}} & \small{Mini Facts -> Mini Facts Level} & 0.716 & 0.827 & \textbf{0.865} & 0.842 & 0.834 \\
 & \small{Mini Facts -> Sentence Level} & 0.732 & 0.835 & \textbf{0.870} & 0.853 & 0.846 \\
 & \small{Sentences} & 0.659 & 0.814 & \textbf{0.832} & 0.824 & 0.827 \\
\addlinespace[5pt]
\hline
\addlinespace[5pt]
\scriptsize{\textbf{HoVer (Llama3-8B) /}} \newline \scriptsize{\textbf{HoVer (Phi)}} & \small{Mini Facts -> Mini Facts Level} & 0.680 & 0.804 & \textbf{0.831} & 0.816 & 0.805 \\
 & \small{Min

Tests

In [4]:
# For each generated sentence: its mini-facts, their labels, the probe probabilities,
# and the sentence-level label.
grouped_mini_fact = df_test_mini_fact.groupby('gen_sentence')

for sentence_text, facts in grouped_mini_fact:
    print(sentence_text)
    print(facts['output_mini_fact'].tolist())
    print(facts['label_mini_fact'].tolist())
    print("Probs: ", facts['pred_prob'].tolist())
    is_same_sentence = df_test_sentence['output_sentence'] == sentence_text
    print(df_test_sentence.loc[is_same_sentence, 'label_sentence'].values[0])
    print("\n")

"Funnybot" is the 15th episode of the fourteenth season of South Park, an American animated television series created by Trey Parker and Matt Stone.
['South Park is an American animated television series.', 'South Park was created by Trey Parker and Matt Stone.']
[1, 1]
Probs:  [0.9734926819801331, 0.9741986393928528]
1


"Funnybot" is the 15th episode of the fourteenth season of South Park.
['Funnybot is the 15th episode of the fourteenth season of South Park.']
[0]
Probs:  [0.26021644473075867]
0


"Get Free" is a song by The Vines from The Vines' debut album, Highly Evolved.
['Get Free is a song by The Vines from their debut album, Highly Evolved.']
[1]
Probs:  [0.23341509699821472]
1


"I Only Want to Be with You" was also recorded by Texas Instruments on Texas Instruments' album Continued Story with Texas Instruments.
['I Only Want to Be with You was also recorded by Texas Instruments on their album Continued Story with Texas Instruments']
[0]
Probs:  [0.11228354275226593]
0


"I'

In [5]:
test_pred_probs_mini_fact, test_accuracy_mini_fact, roc_auc_mini_fact = get_results(
    df_test_mini_fact,
    model_path_mini_fact,
    layer=layer,
    probe_method="mini_fact",
    accuracy_threshold=0.5,
)

df_test_mini_fact['pred_prob'] = test_pred_probs_mini_fact
df_test_mini_fact['binary_pred'] = df_test_mini_fact['pred_prob'].apply(lambda x: 1 if x > 0.5 else 0)

# A sentence is judged correctly when "some mini-fact is labelled 0" agrees with
# "some mini-fact is predicted 0".
true_count = 0
for name, group in df_test_mini_fact.groupby('gen_sentence'):
    print(name)
    print(group['label_mini_fact'].values)

    has_false_label = 0 in group['label_mini_fact'].values
    has_false_pred = 0 in group['binary_pred'].values
    is_correct = has_false_label == has_false_pred
    print(is_correct)
    true_count += int(is_correct)

    sentence_rows = df_test_sentence[df_test_sentence['output_sentence'] == name]
    gen_sentence = sentence_rows['output_sentence'].values[0]
    label_sentence = sentence_rows['label_sentence'].values[0]
    print(gen_sentence)
    print(label_sentence)
    print("###")

true_count / len(df_test_mini_fact['gen_sentence'].unique())

"Funnybot" is the 15th episode of the fourteenth season of South Park, an American animated television series created by Trey Parker and Matt Stone.
[1 1]
True
"Funnybot" is the 15th episode of the fourteenth season of South Park, an American animated television series created by Trey Parker and Matt Stone.
1
###
"Funnybot" is the 15th episode of the fourteenth season of South Park.
[0]
True
"Funnybot" is the 15th episode of the fourteenth season of South Park.
0
###
"Get Free" is a song by The Vines from The Vines' debut album, Highly Evolved.
[1]
False
"Get Free" is a song by The Vines from The Vines' debut album, Highly Evolved.
1
###
"I Only Want to Be with You" was also recorded by Texas Instruments on Texas Instruments' album Continued Story with Texas Instruments.
[0]
True
"I Only Want to Be with You" was also recorded by Texas Instruments on Texas Instruments' album Continued Story with Texas Instruments.
0
###
"I'll Be There for You" is a song written by Donna Weiss and Jackie

0.7482758620689656

In [None]:
from itertools import chain
import pandas as pd

# HoVer mini-fact train/test splits for hidden layer -24.
hover_layer24_dir = "./processed_datasets_with_bart_hover_layer-24"
df_train = pd.read_pickle(f"{hover_layer24_dir}/mini_fact_hover_train.pkl")
df_test = pd.read_pickle(f"{hover_layer24_dir}/mini_fact_hover_test_unbalanced.pkl")

In [None]:
import pandas as pd
from transformers import logging
import spacy
from fastcoref import spacy_component
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers.cross_encoder import CrossEncoder
from transformers import AutoTokenizer
import torch
import os

# Silence transformers' info/warning messages; only errors are reported from here on.
logging.set_verbosity_error()


class DatasetBuilder:
    """Builds per-sentence token-probability tables from generated evidence.

    ``gen_evidence_file`` and ``sentence_file`` are pickled DataFrames joined on
    ``gen_evidence``. ``nlp_processor`` must provide ``convert_text_to_sentences`` and a
    ``roberta_measure_model`` with a ``predict`` method. ``vectorizer`` must provide
    ``fit_transform`` (a TF-IDF vectorizer in this notebook).
    """

    def __init__(self, nlp_processor, vectorizer, gen_evidence_file, sentence_file):
        self.vectorizer = vectorizer
        self.nlp_processor = nlp_processor
        self.gen_evidence_file = gen_evidence_file
        self.sentence_file = sentence_file

    def match_tokens(self, start_index, token_list, sentence):
        """Consume tokens from ``token_list[start_index:]`` until their text covers ``sentence``.

        Each entry of ``token_list`` is a ``(token, prob, entropy)`` triple. Tokens are
        consumed by character length, so the match is approximate when tokenisation adds
        or drops whitespace relative to ``sentence``.

        Returns:
            ``(next_start_index, matched_entries)``.
        """
        remaining = "".join(sentence)
        matched_probs = []
        for tokens, probs, pe in token_list[start_index:start_index + len(remaining)]:
            matched_probs.append((tokens, probs, pe))
            remaining = remaining[len("".join(tokens)):]
            # Stop once the sentence text is used up. The previous check compared a str
            # with [] (always False), so it swallowed tokens from following sentences.
            if not remaining:
                break
        return start_index + len(matched_probs), matched_probs

    def match_tokens_phi(self, start_index, token_list, sentence):
        """Consume tokens from ``token_list[start_index:]`` until the sentence's periods are seen.

        Counts the ``"."`` characters in ``sentence`` and stops after that many tokens
        containing a ``"."`` have been consumed, reading at most ``len(sentence)`` tokens.

        NOTE(review): a sentence without any ``"."`` usually stops after a single token,
        and a token containing several periods is counted once. Confirm this matches the
        Phi tokenisation.

        Returns:
            ``(next_start_index, matched_entries)``.
        """
        sentence_list = "".join(sentence)
        len_sentence_list = len(sentence_list)
        matched_probs = []
        dot = "."
        count = sentence_list.count(dot)

        for tokens, probs, pe in token_list[start_index:start_index + len_sentence_list]:
            matched_probs.append((tokens, probs, pe))
            if dot in tokens:
                count -= 1
            if count == 0:
                break
        return start_index + len(matched_probs), matched_probs

    def merge_with_probs(self):
        """Left-join the sentence table with the generated-evidence table on ``gen_evidence``.

        Duplicate ``gen_evidence`` rows in the evidence table are dropped first, so the
        join cannot multiply sentence rows. Prints how many joined rows have a missing
        ``gen_evidence``.
        """
        df_gen_evidence = pd.read_pickle(self.gen_evidence_file)
        df_gen_evidence = df_gen_evidence.drop_duplicates(subset="gen_evidence")
        df_sentence = pd.read_pickle(self.sentence_file)
        df_sentence = pd.merge(df_sentence, df_gen_evidence, on="gen_evidence", how="left")
        # When both inputs carry `docs`, merge suffixes them; keep the sentence-side copy.
        df_sentence = df_sentence.rename(columns={"docs_x": "docs"})
        print(df_sentence['gen_evidence'].isna().sum())
        return df_sentence

    def concatenate_tokens_with_probs(self, tokens_probs):
        """Group sub-word ``(token, prob, entropy)`` triples into whole words.

        A token starting with a space begins a new word. Returns a list of
        ``(word, tokens, probs, entropies)`` tuples; the first token of each word has its
        surrounding whitespace stripped.
        """
        concatenated = []
        current_word = ""
        current_tokens = []
        current_probs = []
        current_pe = []
        for token, prob, pe in tokens_probs:
            if token.startswith(" "):
                if current_word:
                    concatenated.append((current_word, current_tokens, current_probs, current_pe))
                current_word = token.strip()
                current_tokens = [token.strip()]
                current_probs = [prob]
                current_pe = [pe]
            else:
                current_word += token
                current_tokens.append(token)
                current_probs.append(prob)
                current_pe.append(pe)

        if current_word:
            concatenated.append((current_word, current_tokens, current_probs, current_pe))
        return concatenated

    def create_probs(self, df_sentence):
        """Attach to each output sentence the token probabilities of its evidence sentence.

        For every ``gen_evidence`` group, the evidence is split into sentences and the
        token triples are aligned to them with ``match_tokens_phi``. Each output sentence
        then takes the span of the evidence sentence with the highest TF-IDF cosine
        similarity. Single-sentence evidence uses all tokens. Groups with a missing
        ``concat_probs`` entry are skipped.

        Returns:
            A DataFrame with ``output_sentence``, ``label_sentence``, ``gen_evidence``,
            ``concat_probs_sentence``, ``docs`` and ``claim`` columns.
        """
        columns = ["output_sentence", "label_sentence", "gen_evidence", "concat_probs_sentence", "docs", "claim"]
        records = []  # built once at the end instead of concat-in-a-loop (quadratic)

        for gen_evidence, group in df_sentence.groupby("gen_evidence"):
            if group['concat_probs'].isna().sum() > 0:
                continue

            concat_probs = [
                (item['token'], item['generated_token_prob_from_transition_score'], item['token_wise_predictive_entropy'])
                for item in group['concat_probs'].tolist()[0]
            ]
            # NOTE(review): drops the first 5 and last 8 tokens, presumably prompt/template
            # tokens around the generation. Confirm against the Phi prompt format.
            concat_probs = concat_probs[5:-8]

            docs = group['docs'].tolist()[0]
            claim = group['claim'].tolist()[0]
            sentences = self.nlp_processor.convert_text_to_sentences(gen_evidence)
            multi_sentence = len(sentences) > 1

            # Token spans depend only on the evidence text, so align them once per group.
            sentence_spans = []
            if multi_sentence:
                start_index = 0
                for sentence in sentences:
                    start_index, matched_probs = self.match_tokens_phi(start_index, concat_probs, sentence)
                    sentence_spans.append(matched_probs)

            for output, label in zip(group['output_sentence'].tolist(), group['label_sentence'].tolist()):
                if multi_sentence:
                    cosine_scores = []
                    for sentence in sentences:
                        tfidf_matrix = self.vectorizer.fit_transform([output, sentence])
                        cosine_scores.append(cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0, 0])
                    concat_probs_sentence = sentence_spans[cosine_scores.index(max(cosine_scores))]
                else:
                    concat_probs_sentence = concat_probs

                records.append({
                    "output_sentence": output,
                    "label_sentence": label,
                    "gen_evidence": gen_evidence,
                    "concat_probs_sentence": concat_probs_sentence,
                    "docs": docs,
                    "claim": claim,
                })

        return pd.DataFrame(records, columns=columns)

    @staticmethod
    def _flatten_tokens(concat_probs_sentence):
        """Extract token strings from either supported per-sentence token format.

        Four-field word entries ``(word, tokens, probs, entropies)`` contribute their
        token list. Three-field ``(token, prob, entropy)`` entries contribute the token.
        """
        tokens = []
        for entry in concat_probs_sentence:
            if len(entry) == 4:
                tokens.extend(entry[1])
            else:
                tokens.append(entry[0])
        return tokens

    def get_token_importance(self, df_probs_sentences, batch_size=32):
        """Score each token by how much deleting it changes the sentence.

        For every token, all its occurrences are removed from ``output_sentence`` and the
        cross-encoder compares the result with the original. Importance is
        ``1 - similarity``. Pairs are scored in batches of ``batch_size``.

        Returns:
            A copy of ``df_probs_sentences`` with a ``token_importance`` list column and a
            fresh ``RangeIndex``.
        """
        df_token_importance = df_probs_sentences.copy()
        df_token_importance['token_importance'] = None
        df_token_importance['token_importance'] = df_token_importance['token_importance'].astype(object)
        scorer = self.nlp_processor.roberta_measure_model

        for sample_idx, row in df_probs_sentences.iterrows():
            generated_text = row['output_sentence']
            tokens = self._flatten_tokens(row['concat_probs_sentence'])
            replaced_sentences = [generated_text.replace(token, '') for token in tokens]

            token_importance = []
            for i in range(0, len(replaced_sentences), batch_size):
                batch = replaced_sentences[i:i + batch_size]
                similarity = scorer.predict([(generated_text, sentence) for sentence in batch])
                token_importance.extend((1 - torch.as_tensor(similarity)).reshape(-1).tolist())

            df_token_importance.at[sample_idx, 'token_importance'] = token_importance

        df_token_importance.reset_index(drop=True, inplace=True)
        return df_token_importance

In [None]:
class NLP:
    """spaCy sentence splitter (with the fastcoref component) plus a similarity model.

    Args:
        roberta_model_path: Path the similarity model was loaded from; stored as-is.
        roberta_measure_model: Model exposing ``predict`` on text pairs; stored as-is.
        device: Device for the fastcoref component. Defaults to ``"cuda"`` when torch
            sees a GPU, otherwise ``"cpu"`` (the previous code always used ``"cuda"``).
    """

    def __init__(self, roberta_model_path, roberta_measure_model, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.nlp = spacy.load("en_core_web_sm")
        self.nlp.add_pipe(
            "fastcoref",
            config={
                'model_architecture': 'LingMessCoref',
                'model_path': 'biu-nlp/lingmess-coref',
                'device': device,
            }
        )

        self.roberta_model_path = roberta_model_path
        self.roberta_measure_model = roberta_measure_model

    def convert_text_to_sentences(self, text):
        """Split ``text`` into sentence strings using the spaCy pipeline's segmentation."""
        return [sent.text for sent in self.nlp(text).sents]

In [None]:
if __name__ == "__main__":
    # Local snapshot of cross-encoder/stsb-roberta-large. Set ROBERTA_STSB_PATH to point
    # elsewhere on machines where this absolute path does not exist.
    roberta_model_path = os.environ.get(
        "ROBERTA_STSB_PATH",
        r"D:\huggingface\huggingface\hub\models--cross-encoder--stsb-roberta-large\snapshots\9e35bf01ec28b309411c8903d0d4165567303eb4",
    )
    roberta_measure_model = CrossEncoder(model_name=roberta_model_path, num_labels=1)
    vectorizer = TfidfVectorizer()
    nlp_processor = NLP(roberta_model_path, roberta_measure_model)



In [None]:
dataset = "fever"
gen_evidence_path = f"test_evidence_with_claims/test_{dataset}_phi.pkl"
sentence_path = f"datasets_{dataset}_phi/sentence_with_bart.pkl"

dataset_builder = DatasetBuilder(nlp_processor, vectorizer, gen_evidence_path, sentence_path)
df_sentence = dataset_builder.merge_with_probs()
df_probs_sentences = dataset_builder.create_probs(df_sentence)

# Optional follow-up steps (currently disabled): persist the per-sentence probabilities
# and compute token importance.
# df_probs_sentences.to_pickle(f"probs_test_phi/probs_sentence_{dataset}.pkl")
# df_token_importance = dataset_builder.get_token_importance(df_probs_sentences)
# df_token_importance.to_pickle(f"probs_test_phi/df_new_{dataset}_probs_sentence_with_token_importance.pkl")

In [None]:
import pandas as pd

dataset = "fever"
token_importance_path = f"probs_test_phi/probs_sentence_{dataset}_with_token_importance.pkl"
df_probs_sentences = pd.read_pickle(token_importance_path)

In [None]:
# Inspect the loaded per-sentence table (file name indicates it includes token importance).
df_probs_sentences

In [None]:
for _, sentence_row in df_probs_sentences.iterrows():
    # Second field of every token entry. NOTE(review): its meaning depends on the format
    # the pickle was written with; for (token, prob, entropy) triples it is the probability.
    print([entry[1] for entry in sentence_row['concat_probs_sentence']])
    print("###")