In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, f1_score

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ProbeNN(nn.Module):
    """Feed-forward probe with three ReLU hidden layers and a sigmoid output.

    Layer widths are ``input_dim -> 256 -> 128 -> 64 -> 1``.
    """

    def __init__(self, input_dim):
        super().__init__()
        # The attribute names double as state-dict keys, so they stay unchanged.
        self.layer1 = nn.Linear(input_dim, 256)
        self.layer2 = nn.Linear(256, 128)
        self.layer3 = nn.Linear(128, 64)
        self.output = nn.Linear(64, 1)

    def forward(self, x):
        """Return the sigmoid of the output unit (last dimension has size 1)."""
        hidden = x
        for dense in (self.layer1, self.layer2, self.layer3):
            hidden = torch.relu(dense(hidden))
        return torch.sigmoid(self.output(hidden))



def evaluate_unbalanced_testset(test_pred_probs, df_test, accuracy_threshold):
    """Score mini-fact predictions at the level of whole generated sentences.

    Rows of ``df_test`` are grouped by ``gen_sentence``. A group's true label is
    1 only when every ``label_mini_fact`` in it is 1; its predicted label is 1
    only when every row's probability exceeds ``accuracy_threshold``. A group
    counts as correct when the two labels agree.

    Args:
        test_pred_probs: Per-row predicted probabilities, in the row order of
            ``df_test``.
        df_test: DataFrame with ``gen_sentence`` and ``label_mini_fact``
            columns. It is copied, not modified.
        accuracy_threshold: Probability above which a row is predicted as 1.

    Returns:
        Tuple ``(accuracy, collect1s, collect0s, auc_roc_score)``: the share of
        correctly classified groups, the number of groups whose labels are all
        1, the number of remaining groups, and the ROC-AUC computed from the
        group labels and each group's minimum predicted probability.
    """
    df_pred = df_test.copy()
    df_pred['pred_prob'] = test_pred_probs
    # Binarise with the caller-supplied threshold rather than a fixed cut-off.
    df_pred['binary_pred'] = (df_pred['pred_prob'] > accuracy_threshold).astype(int)

    # Pure aggregations: no counters mutated from inside groupby.apply, so the
    # counts cannot depend on how often pandas invokes a callback per group.
    by_sentence = df_pred.groupby('gen_sentence')
    group_size = by_sentence.size()
    n_groups = len(group_size)

    # A group is "all true" when the sum of its 0/1 labels equals its row count.
    true_group_label = (by_sentence['label_mini_fact'].sum() == group_size).astype(int)
    # min == 1 means every row was predicted 1; min == 0 means at least one 0.
    pred_group_label = by_sentence['binary_pred'].min()

    num_correct = (pred_group_label == true_group_label).sum()
    accuracy = num_correct / n_groups

    collect1s = int(true_group_label.sum())
    collect0s = n_groups - collect1s

    # The weakest (minimum) row probability serves as the group's score.
    pred_group_prob = by_sentence['pred_prob'].min()
    auc_roc_score = roc_auc_score(true_group_label, pred_group_prob)
    return accuracy, collect1s, collect0s, auc_roc_score



def compute_roc_curve(test_labels, test_pred_prob):
    """Return ``(roc_auc, fpr, tpr)`` for the given labels and predicted scores.

    The third value returned by ``roc_curve`` is discarded.
    """
    false_pos_rate, true_pos_rate, _unused = roc_curve(test_labels, test_pred_prob)
    return auc(false_pos_rate, true_pos_rate), false_pos_rate, true_pos_rate


def get_results(df_test, model_path, layer, probe_method, accuracy_threshold):
    """Load a saved probe and score it on the test embeddings.

    Args:
        df_test: DataFrame with an ``embeddings{layer}_{probe_method}`` column
            of per-row embedding vectors and a ``label_{probe_method}`` column.
        model_path: Path to the saved ``ProbeNN`` state dict.
        layer: Layer identifier used to build the embedding column name.
        probe_method: Suffix shared by the embedding and label column names.
        accuracy_threshold: Probability above which a row is predicted as 1.

    Returns:
        Tuple ``(test_pred_prob, test_accuracy, roc_auc, f1_score_positive,
        f1_score_negative)``, where ``test_pred_prob`` is the NumPy array of
        model outputs and the F1 scores use ``pos_label=1`` and ``pos_label=0``.
    """
    test_embeddings = np.array(df_test[f'embeddings{layer}_{probe_method}'].tolist())
    test_labels = df_test[f'label_{probe_method}']

    model = ProbeNN(test_embeddings.shape[1]).to(device)
    # map_location keeps a checkpoint that was saved on a GPU loadable when
    # `device` resolves to the CPU.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    with torch.no_grad():
        inputs = torch.tensor(test_embeddings, dtype=torch.float32).to(device)
        test_pred_prob = model(inputs).cpu().numpy()

    test_pred = (test_pred_prob > accuracy_threshold).astype(int)

    # Calculate accuracy
    test_accuracy = accuracy_score(test_labels, test_pred)

    # Calculate ROC-AUC
    roc_auc = roc_auc_score(test_labels, test_pred_prob)

    # Calculate F1 score for positive (label 1) and negative (label 0) classes
    f1_score_positive = f1_score(test_labels, test_pred, pos_label=1)
    f1_score_negative = f1_score(test_labels, test_pred, pos_label=0)

    return test_pred_prob, test_accuracy, roc_auc, f1_score_positive, f1_score_negative


In [None]:
results_fever_test = []
results_hover_test = []

train_datasets = ["hover"]
test_datasets = ["hover"]

layer = -16
probe = "with_train_popularity_balanced"
test_real_samples = False
# One decision threshold shared by the metrics and the saved binary predictions,
# so the two cannot drift apart.
accuracy_threshold = 0.5
# Fixed seed so the shuffled order of the real samples is identical on every run.
shuffle_seed = 42

# NOTE: read_pickle unpickles the file; only load data files you trust.
df_test_mini_fact1 = pd.read_pickle("./test/test_llm_generations.pkl")

if test_real_samples:
    # we only use the LLMs negative samples
    df_test_mini_fact1 = df_test_mini_fact1[df_test_mini_fact1['label_mini_fact'] == 0]
    # here are only positive real samples
    df_test_mini_fact2 = pd.read_pickle("./test/test_all_popularity_real_samples.pkl")
    df_test_mini_fact2 = (
        df_test_mini_fact2
        .sample(frac=1, random_state=shuffle_seed)
        .reset_index(drop=True)
    )
    df_test_mini_fact = pd.concat([df_test_mini_fact1, df_test_mini_fact2], axis=0)
else:
    df_test_mini_fact = df_test_mini_fact1

model_path_mini_fact = f"./probes/{probe}.pth"


(
    test_pred_probs_mini_fact,
    test_accuracy_mini_fact,
    roc_auc_mini_fact,
    f1_score_positive,
    f1_score_negative,
) = get_results(
    df_test_mini_fact,
    model_path_mini_fact,
    layer=layer,
    probe_method="mini_fact",
    accuracy_threshold=accuracy_threshold,
)

# Flatten the model output so the column holds one scalar probability per row.
df_test_mini_fact['pred_prob'] = np.ravel(test_pred_probs_mini_fact)
df_test_mini_fact['binary_pred'] = (
    df_test_mini_fact['pred_prob'] > accuracy_threshold
).astype('int64')

if test_real_samples:
    df_test_mini_fact.to_pickle(f"predictions/prediction_with_real_samples_{probe}.pkl")
else:
    df_test_mini_fact.to_pickle(f"predictions/prediction_{probe}.pkl")
