This notebook evaluates the trained probes using AUROC and accuracy.

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import accuracy_score

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ProbeNN(nn.Module):
    """Feed-forward probe that maps an input vector to a single sigmoid output.

    Three ReLU-activated linear layers (input_dim -> 256 -> 128 -> 64) feed a
    one-unit linear head whose output is passed through a sigmoid.
    """

    def __init__(self, input_dim):
        """Create the layers; ``input_dim`` is the size of one input vector."""
        super().__init__()
        # The attribute names determine the state_dict keys used when loading
        # saved probes, so they are kept exactly as they are.
        self.layer1 = nn.Linear(input_dim, 256)
        self.layer2 = nn.Linear(256, 128)
        self.layer3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)

    def forward(self, x):
        """Return the sigmoid of the head; the last dimension has size 1."""
        hidden = x
        for linear in (self.layer1, self.layer2, self.layer3):
            hidden = torch.relu(linear(hidden))
        return torch.sigmoid(self.output(hidden))



def evaluate_unbalanced_testset(test_pred_probs, df_test, accuracy_threshold):
    """Score mini-fact predictions at the level of their generated sentence.

    Rows of ``df_test`` are grouped by ``gen_sentence``. A sentence counts as
    "true" when every ``label_mini_fact`` in its group is 1, and is predicted
    "true" when every mini-fact probability in its group exceeds
    ``accuracy_threshold``. The sentence-level AUROC uses the minimum mini-fact
    probability of each group as the sentence score.

    Args:
        test_pred_probs: Predicted probabilities, one per row of ``df_test`` in
            row order. Any array shape that flattens to ``len(df_test)`` is
            accepted (e.g. ``(N,)`` or ``(N, 1)``).
        df_test: DataFrame with the columns ``gen_sentence`` and
            ``label_mini_fact``. It is not modified.
        accuracy_threshold: A mini-fact is predicted as 1 when its probability
            is strictly greater than this value.

    Returns:
        Tuple ``(accuracy, collect1s, collect0s, auc_roc_score)`` where
        ``accuracy`` is the fraction of sentences whose predicted verdict
        matches the label verdict, ``collect1s`` / ``collect0s`` are the number
        of sentences with all-true / not-all-true labels, and ``auc_roc_score``
        is the sentence-level AUROC.
    """
    df_pred = df_test.copy()
    # Flatten so each row receives one scalar probability even for (N, 1) input.
    df_pred['pred_prob'] = np.asarray(test_pred_probs).reshape(-1)
    # Honour the caller's threshold instead of a hard-coded 0.5.
    df_pred['binary_pred'] = (df_pred['pred_prob'] > accuracy_threshold).astype(int)

    # One pass over the groups; no counters mutated from inside groupby.apply.
    per_sentence = df_pred.groupby('gen_sentence').agg(
        n_facts=('label_mini_fact', 'size'),
        n_true_labels=('label_mini_fact', 'sum'),
        n_positive_preds=('binary_pred', 'sum'),
        pred_group_prob=('pred_prob', 'min'),
    )

    all_labels_true = per_sentence['n_true_labels'] == per_sentence['n_facts']
    all_preds_true = per_sentence['n_positive_preds'] == per_sentence['n_facts']

    # All-true labels require all-true predictions; otherwise at least one
    # prediction must be 0. Both cases reduce to "the two verdicts agree".
    is_correct = all_labels_true == all_preds_true
    accuracy = is_correct.sum() / len(per_sentence)

    collect1s = int(all_labels_true.sum())
    collect0s = int((~all_labels_true).sum())

    # Calculate AUROC on the sentence level
    auc_roc_score = roc_auc_score(all_labels_true.astype(int), per_sentence['pred_group_prob'])
    return accuracy, collect1s, collect0s, auc_roc_score


def compute_roc_curve(test_labels, test_pred_prob):
    """Return ``(roc_auc, fpr, tpr)`` for the given labels and predicted scores.

    The false/true positive rates are the first two outputs of ``roc_curve``
    (its thresholds are discarded); the area is computed from them with ``auc``.
    """
    false_pos_rate, true_pos_rate = roc_curve(test_labels, test_pred_prob)[:2]
    return auc(false_pos_rate, true_pos_rate), false_pos_rate, true_pos_rate


def get_results(df_test, model_path, layer, probe_method, accuracy_threshold):
    """Load a saved probe and evaluate it on one test DataFrame.

    Args:
        df_test: DataFrame containing the columns
            ``embeddings{layer}_{probe_method}`` (one vector per row) and
            ``label_{probe_method}``.
        model_path: Path to the saved ``ProbeNN`` state dict.
        layer: Layer identifier used in the embedding column name.
        probe_method: Suffix used in the embedding/label column names
            (the notebook uses "sentence" and "mini_fact").
        accuracy_threshold: Predictions strictly above this value count as 1
            when computing accuracy.

    Returns:
        Tuple ``(test_pred_prob, test_accuracy, roc_auc)``. ``test_pred_prob``
        is the raw model output as a NumPy array (one row per test row, one
        column).
    """
    test_embeddings = np.array(df_test[f'embeddings{layer}_{probe_method}'].tolist())
    test_labels = df_test[f'label_{probe_method}']

    model = ProbeNN(test_embeddings.shape[1]).to(device)
    # map_location makes a checkpoint saved on a GPU loadable on the CPU
    # fallback selected by `device`.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    with torch.no_grad():
        inputs = torch.tensor(test_embeddings, dtype=torch.float32).to(device)
        test_pred_prob = model(inputs).cpu().numpy()

    # Score on a flat view so labels and predictions are both 1-D; the returned
    # array keeps the model's original output shape.
    flat_probs = test_pred_prob.reshape(-1)
    test_accuracy = accuracy_score(test_labels, flat_probs > accuracy_threshold)
    roc_auc, _, _ = compute_roc_curve(test_labels, flat_probs)
    return test_pred_prob, test_accuracy, roc_auc


In [19]:
# Accumulators: one list of per-layer result records for each test dataset.
results_fever_test, results_hover_test = [], []

# Datasets are paired position-wise: a probe trained on train_datasets[i]
# is evaluated on test_datasets[i].
train_datasets = ["hover"]
test_datasets = ["hover"]

model_name = "llama"  # selects the dataset directory layout ("llama" or "phi")
layers = [-1]  # layer identifiers to evaluate
balanced = False  # False: load the *_unbalanced pickles and balance the sentences on the fly

#caption = "\\textbf{AUROC} across different layers for FEVER and HoVer test datasets"
#caption = "\\textbf{AUROC} scores across different layers for cross-testing: probes trained on HoVer are tested on FEVER, and probes trained on FEVER are tested on HoVer."
#caption = "\\textbf{AUROC} scores for probes trained on Llama3-8B embeddings and tested on Phi generations"

#label = "tab:auroc_probes"
#label = "tab:auroc_probes_cross_test" 
#label = "tab:auroc_probes_phi"



def balance_dataframe(df, label_name):
    """Downsample ``df`` so that labels 1 and 0 occur equally often.

    Rows whose ``label_name`` column equals 1 and rows where it equals 0 are
    each sampled (without replacement, ``random_state=42``) down to the size of
    the smaller class. The result lists the label-1 rows first, then the
    label-0 rows, with a fresh 0..n-1 index.
    """
    column = str(label_name)
    positives = df[df[column] == 1]
    negatives = df[df[column] == 0]
    n_per_class = min(len(positives), len(negatives))
    halves = [part.sample(n_per_class, random_state=42) for part in (positives, negatives)]
    return pd.concat(halves).reset_index(drop=True)

# LaTeX row labels for the layers used in the result tables; any other layer maps to None.
LAYER_DISPLAY_NAMES = {
    1: "First \\newline Hidden \\newline Layer",
    -24: "Hidden \\newline Layer: \\newline 9",
    -16: "Hidden \\newline Layer: \\newline 17",
    -8: "Hidden \\newline Layer: \\newline 25",
    -1: "Last \\newline Hidden \\newline Layer",
}

# Decision threshold shared by the probe accuracy and the sentence-level evaluation.
ACCURACY_THRESHOLD = 0.5

# Destination list per supported test dataset; other datasets are evaluated but not recorded.
RESULTS_BY_TEST_DATASET = {"fever": results_fever_test, "hover": results_hover_test}

for test_dataset, train_dataset in zip(test_datasets, train_datasets):
    for layer in layers:
        print(f"Layer: {layer}")
        layer_name = LAYER_DISPLAY_NAMES.get(layer)

        # --- Load the test frames for the selected model ---------------------
        if model_name == "llama":
            test_dataset_name = f"processed_datasets_{model_name}_{test_dataset}_layer{layer}"
            split = "balanced" if balanced else "unbalanced"
            df_test_mini_fact = pd.read_pickle(f"./{test_dataset_name}/mini_fact_{test_dataset}_test_{split}.pkl")
            df_test_sentence = pd.read_pickle(f"./{test_dataset_name}/sentence_{test_dataset}_test_{split}.pkl")
            needs_balancing = not balanced
        elif model_name == "phi":
            # Only unbalanced pickles are read for phi, so they are always balanced here.
            test_dataset_name = f"processed_datasets_{model_name}"
            df_test_mini_fact = pd.read_pickle(f"./{test_dataset_name}/mini_fact_{test_dataset}_layer{layer}_test_unbalanced.pkl")
            df_test_sentence = pd.read_pickle(f"./{test_dataset_name}/sentence_{test_dataset}_layer{layer}_test_unbalanced.pkl")
            needs_balancing = True
        else:
            # Fail loudly instead of continuing with undefined or stale DataFrames.
            raise ValueError(f"Unsupported model_name: {model_name!r} (expected 'llama' or 'phi')")

        if needs_balancing:
            df_test_sentence = balance_dataframe(df_test_sentence, 'label_sentence')
            keep_rows = df_test_mini_fact['gen_sentence'].isin(df_test_sentence['output_sentence'])
            # .copy() so the column assignments below write to an independent frame,
            # not to a slice of the loaded one (avoids SettingWithCopyWarning).
            df_test_mini_fact = df_test_mini_fact[keep_rows].copy()

        model_path_sentence = f"./probes/sentence_embeddings{layer}_{train_dataset}.pth"
        model_path_mini_fact = f"./probes/mini_fact_embeddings{layer}_{train_dataset}.pth"

        # --- Probe-level metrics ---------------------------------------------
        test_pred_probs_sentences, test_accuracy_sentences, roc_auc_sentences = get_results(
            df_test_sentence,
            model_path_sentence,
            layer=layer,
            probe_method="sentence",
            accuracy_threshold=ACCURACY_THRESHOLD,
        )

        test_pred_probs_mini_fact, test_accuracy_mini_fact, roc_auc_mini_fact = get_results(
            df_test_mini_fact,
            model_path_mini_fact,
            layer=layer,
            probe_method="mini_fact",
            accuracy_threshold=ACCURACY_THRESHOLD,
        )

        # Flatten so each row gets one scalar probability.
        df_test_mini_fact['pred_prob'] = np.asarray(test_pred_probs_mini_fact).reshape(-1)
        df_test_mini_fact['binary_pred'] = (df_test_mini_fact['pred_prob'] > ACCURACY_THRESHOLD).astype(int)
        # NOTE: the same file is overwritten on every iteration; only the last one survives.
        df_test_mini_fact.to_pickle("prediction.pkl")

        # --- Sentence-level metrics derived from the mini-fact predictions ---
        if balanced:
            auc_roc_score_mini_facts = None
        else:
            correct_ratio, collect1s, collect0s, auc_roc_score_mini_facts = evaluate_unbalanced_testset(
                test_pred_probs_mini_fact,
                df_test_mini_fact,
                accuracy_threshold=ACCURACY_THRESHOLD,
            )

        # --- Record the results -----------------------------------------------
        results_target = RESULTS_BY_TEST_DATASET.get(test_dataset)
        if results_target is not None:
            print(f"Test dataset: {test_dataset}")
            results_target.append({
                "train_dataset": train_dataset,
                "test_dataset": test_dataset,
                "layer": layer_name,
                # Accuracies are only reported for balanced test sets.
                "accuracySentence": test_accuracy_sentences if balanced else None,
                "auc_sentences": roc_auc_sentences,
                "accuracyMiniFacts": test_accuracy_mini_fact if balanced else None,
                "auc_mini_facts": roc_auc_mini_fact,
                "auc_mini_facts_sentences_match": auc_roc_score_mini_facts,
            })

Layer: -1
Test dataset: hover


In [20]:
# Display the HoVer result records appended by the evaluation loop.
results_hover_test

[{'train_dataset': 'hover',
  'test_dataset': 'hover',
  'layer': 'Last \\newline Hidden \\newline Layer',
  'accuracySentence': None,
  'auc_sentences': 0.7846093400217464,
  'accuracyMiniFacts': None,
  'auc_mini_facts': 0.8162134434837778,
  'auc_mini_facts_sentences_match': 0.8105132352000586}]

Tests

In [None]:
# Manual inspection: for every generated sentence, print its mini-facts, their
# labels and predicted probabilities, followed by the sentence-level label.
grouped_mini_fact = df_test_mini_fact.groupby('gen_sentence')

for name, group in grouped_mini_fact:
    print(name)
    for column in ('output_mini_fact', 'label_mini_fact'):
        print(group[column].tolist())
    print("Probs: ", group['pred_prob'].tolist())
    # First sentence row whose text equals this group's key
    matching_rows = df_test_sentence['output_sentence'] == name
    sentence_label = df_test_sentence.loc[matching_rows, 'label_sentence'].values[0]
    print(sentence_label)
    print("\n")

In [None]:
# Re-score the mini-fact test set and attach probabilities plus thresholded predictions.
test_pred_probs_mini_fact, test_accuracy_mini_fact, roc_auc_mini_fact = get_results(
    df_test_mini_fact,
    model_path_mini_fact,
    layer=layer,
    probe_method="mini_fact",
    accuracy_threshold=0.5,
)

df_test_mini_fact['pred_prob'] = test_pred_probs_mini_fact
df_test_mini_fact['binary_pred'] = df_test_mini_fact['pred_prob'].gt(0.5).astype('int64')
true_count = 0

for name, group in df_test_mini_fact.groupby('gen_sentence'):
    print(name)
    print(group['label_mini_fact'].values)

    # A sentence is judged correctly when "contains a 0" agrees between the
    # mini-fact labels and the mini-fact predictions.
    label_has_zero = 0 in group['label_mini_fact'].values
    pred_has_zero = 0 in group['binary_pred'].values
    if label_has_zero == pred_has_zero:
        print("True")
        true_count += 1
    else:
        print("False")

    # Look the sentence up once and read both fields from the same rows.
    sentence_rows = df_test_sentence[df_test_sentence['output_sentence'] == name]
    gen_sentence = sentence_rows['output_sentence'].values[0]
    label_sentence = sentence_rows['label_sentence'].values[0]
    print(gen_sentence)
    print(label_sentence)
    print("###")

# Fraction of sentences judged correctly (displayed as the cell output)
true_count / len(df_test_mini_fact['gen_sentence'].unique())