# Separation between true/false statements

In [None]:
import numpy as np
import os
import pandas as pd
from tqdm import tqdm

import torch
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, AutoConfig,
    QuantoConfig
)

from utils import *

# Checkpoint selection: leave exactly one MODEL_PATH assignment uncommented.
# MODEL_PATH = "/Model/meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "/Model/meta-llama/Llama-2-7b-chat-hf"
# MODEL_PATH = "/Model/meta-llama/Llama-2-13b-hf"
# MODEL_PATH = "/Model/meta-llama/Llama-2-13b-chat-hf"
# MODEL_PATH = "/Model/meta-llama/Meta-Llama-3.1-8B-hf"
# MODEL_PATH = "/Model/meta-llama/Meta-Llama-3.1-8B-Instruct-hf"
# MODEL_PATH = "/Model/meta-llama/Meta-Llama-3.1-70B-hf"
# MODEL_PATH = "/Model/meta-llama/Meta-Llama-3.1-70B-Instruct-hf"

# MODEL_PATH = "/Model/mistralai/Mistral-7B-v0.1"
MODEL_PATH = "/Model/mistralai/Mistral-7B-Instruct-v0.1"
# MODEL_PATH = "/Model/mistralai/Mistral-Large-Instruct-2407"

# Last path component; used below to build the activation and figure paths.
model_name = os.path.basename(MODEL_PATH)

# Only the model config is loaded here: the cells below read
# config.num_hidden_layers. Tokenizer and model loading are left disabled.
config = AutoConfig.from_pretrained(MODEL_PATH)
# tok = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
# if tok.pad_token is None:
#     tok.pad_token = tok.eos_token
# tok.padding_side = 'left'

# quant_config = QuantoConfig(weights='float8')
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_PATH,
#     # torch_dtype=torch.bfloat16,
#     device_map='auto',
#     attn_implementation="eager",
#     quantization_config=quant_config
# )


In [None]:
# Directory holding the statement CSV files.
dataset_dir = "data/"

# Base topics; each one is expanded into four dataset names below.
topics = [
    "animal_class",
    "cities",
    "facts",
    "sp_en_trans",
    "inventors",
    "element_symb",
]

# For every topic, in order: the plain name, the "neg_"-prefixed name, and the
# "_conj" / "_disj"-suffixed names. Variants of one topic stay adjacent.
dataset_names = [
    variant
    for topic in topics
    for variant in (topic, "neg_" + topic, topic + "_conj", topic + "_disj")
]

prompt_option = "no_prompt"

In [None]:
# Load the cached per-layer activations and the labels of every dataset.
acts_dir = f"activations_and_labels/{model_name}/{prompt_option}"
os.makedirs(acts_dir, exist_ok=True)
activations_dict, labels_dict = {}, {}
for dataset_name in dataset_names:
    print(dataset_name)
    task_dir = os.path.join(acts_dir, dataset_name)
    # One acts_<layer>.npy file per hidden layer, combined along a new leading axis.
    layer_files = [
        os.path.join(task_dir, f"acts_{layer}.npy")
        for layer in range(config.num_hidden_layers)
    ]
    activations_dict[dataset_name] = np.array([np.load(path) for path in layer_files])
    # Labels come from the 'label' column of the dataset's CSV file.
    label_column = pd.read_csv(f"data/{dataset_name}.csv")['label']
    labels_dict[dataset_name] = np.array(label_column.tolist())


In [None]:
import matplotlib.pyplot as plt
import seaborn
import pandas as pd

seaborn.set_style('whitegrid')

fig = plt.figure(figsize=(8, 4))
# Sum over all datasets of the per-layer ratio; the scalar 0 becomes an array
# on the first in-place addition.
relative_variances_all = 0
# Long-format records: one row per (dataset, layer).
data_df = {"data": [], "index": [], "group": []}
num_layers = config.num_hidden_layers
for k, activations_by_layer in activations_dict.items():
    labels = labels_dict[k]
    # The class masks depend only on the labels, so build them once per dataset.
    false_stmnt_ids = labels == 0
    true_stmnt_ids = labels == 1

    between_class_variances = []
    within_class_variances = []
    for layer_nr in range(num_layers):
        layer_acts = activations_by_layer[layer_nr]
        false_acts = layer_acts[false_stmnt_ids]
        true_acts = layer_acts[true_stmnt_ids]

        # Within-class variance: per-feature variance averaged over features,
        # then averaged over the two classes.
        within_false = np.var(false_acts, axis=0).mean()
        within_true = np.var(true_acts, axis=0).mean()
        within_class_variances.append((within_false + within_true) / 2)

        # Between-class variance: squared distance of each class mean from the
        # overall mean, averaged over features and halved.
        overall_mean = layer_acts.mean(axis=0)
        squared_deviation = (
            (false_acts.mean(axis=0) - overall_mean)**2
            + (true_acts.mean(axis=0) - overall_mean)**2
        )
        between_class_variances.append(squared_deviation.mean().item() / 2)

    relative_variances = np.array(between_class_variances) / np.array(within_class_variances)
    relative_variances_all += relative_variances

    # Group the curve under its base topic by stripping the variant markers.
    # (Alternative: group by statement type instead - neg / conj / disj / affirm.)
    group = k.replace("neg_","").replace("_conj","").replace("_disj","")
    data_df['data'].extend(relative_variances)
    data_df['index'].extend(range(num_layers))
    data_df['group'].extend([group] * num_layers)

data_df = pd.DataFrame(data_df)
draw_df = data_df.reset_index().melt(id_vars=['index', 'group'], var_name='col')
# One line per topic; the band is the standard error across that topic's datasets.
ax = seaborn.lineplot(data_df, x='index', y='data', hue='group', errorbar=('se'))

# Annotate every topic's curve with the layer at which its mean ratio peaks.
for i, topic in enumerate(topics):
    topic_data = data_df.loc[data_df['group'] == topic]
    line = np.zeros(num_layers)
    for layer_idx, ratio in zip(topic_data['index'], topic_data['data']):
        line[layer_idx] += ratio
    # Each topic contributes four datasets (plain, neg_, _conj, _disj).
    line = line / 4
    peak_index = np.argmax(line)
    plt.annotate(peak_index, xy=(peak_index, line[peak_index]), c=ax.lines[i].get_color(), fontsize=14)

ax.legend_.set_title('')
ax.set_xlabel("Layer", fontsize=14)
ax.set_ylabel("Ratio", fontsize=14)
ax.tick_params(labelsize=14)
plt.setp(ax.get_legend().get_texts(), fontsize=14)

# Save the figure, replacing any previous version of the file.
variance_dir = "figures_relative_variance/"
os.makedirs(variance_dir, exist_ok=True)
figname = os.path.join(variance_dir, f"{model_name}_{prompt_option}.pdf")
if os.path.exists(figname):
    os.remove(figname)
plt.tight_layout()
plt.savefig(figname)

# Layer with the largest ratio summed over all datasets.
print(relative_variances_all.argmax())


# Probe performance by layer

In [None]:
import os
import json
import matplotlib.pyplot as plt
import seaborn
# Use seaborn's 'whitegrid' style for the figures in this section.
seaborn.set_style('whitegrid')


In [None]:
results_dir = "results/"

# Model whose probe results are plotted (overrides the model_name set earlier).
# model_name = "Mistral-Large-Instruct-2407"
# model_name = "Meta-Llama-3.1-70B-Instruct-hf"
model_name = "Meta-Llama-3.1-70B-hf"
# model_name = "Meta-Llama-3.1-8B-Instruct-hf"
# model_name = "Llama-2-7b-chat-hf"

# Upper end of the layer axis in the plots below; set by hand per model.
n_layers = 80
prompt_option = "no_prompt"
topics = [
    "cities",
    # "inventors",
    "sp_en_trans",
    # "element_symb",
    # "element_symb_new",
    # "companies",
    # "facts",
]

# results[topic] is the list of JSON records in that topic's .jsonl file,
# one record per line, in file order.
results = {}
for topic in topics:
    result_file = os.path.join(results_dir, f"{model_name}_{topic}_{prompt_option}.jsonl")
    with open(result_file) as fp:
        results[topic] = [json.loads(raw_line) for raw_line in fp]


In [None]:
# Cosine similarity stored in each record, one curve per topic.
plt.figure(figsize=(12,6))
for topic in topics:
    cosine_sims = [rec['cosine_sim'] for rec in results[topic]]
    plt.plot(cosine_sims, label=topic, marker='.')

plt.legend()
# Dashed reference line at zero similarity.
plt.plot([0,n_layers], [0,0], 'k--')
plt.xticks(range(0, n_layers, 10))
plt.xlabel("Layers (starting from 0)")
plt.ylabel("Cosine similarity between\naffirmative/negated truthful direction")


In [None]:
# Half-open range [first, last) of layers to display, and the x-tick spacing.
layers_range = [0, n_layers]
sep = 10

fig, axes = plt.subplots(1, 3, figsize=(21,4))
probe_names = ('lr', 'mlp', 'mm')
first_layer, last_layer = layers_range
for i, probe_name in enumerate(probe_names):
    for topic in topics:
        # AUROC of this probe type from every record of the topic, in file order.
        aurocs = [
            probe['auroc']
            for rec in results[topic]
            for probe in rec['probes']
            if probe['probe'] == probe_name
        ]
        shown = aurocs[first_layer:last_layer]
        # Plot against the real layer indices so the curve stays aligned with
        # the x-ticks when layers_range does not start at 0, and derive the
        # x length from the data in case fewer layers than requested exist.
        layer_ids = range(first_layer, first_layer + len(shown))
        axes[i].plot(layer_ids, shown, label=topic, marker='.')
    axes[i].legend()
    # Dashed chance-level reference line.
    axes[i].plot(layers_range, [0.5,0.5], 'k--')
    axes[i].set_xticks(range(first_layer, last_layer+1, sep))
    axes[i].set_xlabel("Layers (starting from 0)")
    axes[i].set_ylabel("AUROC")
    axes[i].set_title(probe_name)


In [None]:
# Half-open range [first, last) of layers to display.
layers_range = [0, n_layers]

fig, axes = plt.subplots(1, 4, figsize=(19,4))
probe_names = ('lr', 'mlp', 'mm', 'svm')
first_layer, last_layer = layers_range
for i, probe_name in enumerate(probe_names):
    for topic in topics:
        results_topic = results[topic]
        accs = []
        for rec in results_topic:
            for probe in rec['probes']:
                if probe['probe'] != probe_name:
                    continue
                # Prefer 'best_acc' when the record provides one; otherwise use 'acc'.
                best_acc = probe['best_acc']
                accs.append(best_acc if best_acc is not None else probe['acc'])
        shown = accs[first_layer:last_layer]
        # Size the x values from the data: a fixed range(first, last) raises a
        # shape mismatch when the file has fewer layers than n_layers.
        layer_ids = range(first_layer, first_layer + len(shown))
        axes[i].plot(layer_ids, shown, label=topic, marker='.')
    axes[i].legend()
    # Dashed chance-level reference line.
    axes[i].plot(layers_range, [0.5,0.5], 'k--')
    axes[i].set_xticks(range(first_layer, last_layer+1, 10))
    axes[i].set_xlabel("Layers (starting from 0)")
    axes[i].set_ylabel("Best acc")
    axes[i].set_title(probe_name)


# Generalization from affirmative statements to negative statements

In [None]:
import numpy as np
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn

# Models included in the negation-generalization figures. The position in this
# tuple determines the short "M<i>" label built when the results are loaded.
model_names = (
    "Llama-2-7b-hf",
    "Llama-2-7b-chat-hf",
    "Llama-2-13b-hf",
    "Llama-2-13b-chat-hf",
    "Meta-Llama-3.1-8B-hf",
    "Meta-Llama-3.1-8B-Instruct-hf",
    "Meta-Llama-3.1-70B-hf",
    "Meta-Llama-3.1-70B-Instruct-hf",
)
# layer_indices = (
#     13,
#     37,
# )

prompt_option = "no_prompt"
# Topic columns read from each per-probe results CSV.
topics = (
    "animal_class",
    "cities",
    "inventors",
    "facts",
    "element_symb",
    # "companies",
    "sp_en_trans",
)
# Probe types; one results CSV per probe and seed is read below.
probe_names = ("lr", "mlp", "svm", "mm")

# One entry per seed is appended below: {probe: {model label: {topic: AUROC}}}.
aurocs = []
# for probe in probe_names:
#     aurocs[probe] = {}
#     for model_name, layer_index in zip(model_names, layer_indices):
#         aurocs_model = {}
#         for topic in topics:
#             results = []
#             with open(os.path.join("results/", f"{model_name}_{topic}_{prompt_option}.jsonl")) as fp:
#                 for line in fp:
#                     results.append(json.loads(line))
#             results = results[layer_index]
#             aurocs_topic = []
#             for rec in results['probes']:
#                 if rec['probe'] == probe:
#                     aurocs_topic = rec['auroc']
#             aurocs_model[topic] = aurocs_topic
#         aurocs[probe][model_name.strip("-hf")] = aurocs_model

# Read the per-seed, per-probe CSVs and collect AUROCs as
# aurocs[seed_idx][probe]["M<i>"][topic].
results_dir = "neg_generalization_results/"
seeds = [0,1,2]
for seed in seeds:
    aurocs_seed = {}
    for probe in probe_names:
        per_model = aurocs_seed[probe] = {}
        probe_results = pd.read_csv(os.path.join(results_dir, f"seed={seed}/{probe}.csv"))
        for _, rec in probe_results.iterrows():
            # Topics whose cell is NaN are left out of the row's mapping.
            aurocs_model = {
                topic: rec[topic]
                for topic in topics
                if not np.isnan(rec[topic])
            }
            model_name = rec['model']
            if model_name not in model_names:
                continue
            # Short label "M<i>", with i the model's position in model_names.
            m = "M" + str(model_names.index(model_name))
            per_model[m] = aurocs_model
    aurocs.append(aurocs_seed)

# aurocs_avg = {}
# for auroc_seed in aurocs:
#     for probe, auroc_model in auroc_seed.items():
#         aurocs_avg[probe] = {}
#         for model_name, auroc_topic in auroc_model.items():
#             aurocs_avg[probe][model_name] = {}
#             for topic, auroc in auroc_topic.items():
#                 if topic not in aurocs_avg[probe][model_name]:
#                     aurocs_avg[probe][model_name][topic] = 0
#                 aurocs_avg[probe][model_name][topic] += auroc
# for probe, auroc_model in aurocs_avg.items():
#     for model_name, auroc_topic in auroc_model.items():
#         for topic, auroc in auroc_topic.items():
#             # x = np.round(aurocs_avg[probe][model_name][topic], 2)
#             # if x == 0.0:
#             #     x = 0
#             # elif x == 1.0:
#             #     x = 1
#             x = np.round(aurocs_avg[probe][model_name][topic]*100, 0).astype(np.int32)
#             aurocs_avg[probe][model_name][topic] = x


In [None]:
import torch.utils._pytree as pytree_utils

# Average every AUROC leaf across all per-seed trees in `aurocs` and convert it
# to a rounded integer percentage. Handles any number of seeds (it used to be
# hard-coded to exactly three); all trees must share the same structure.
aurocs_avg = pytree_utils.tree_map(
    lambda *per_seed: round(sum(per_seed) / len(per_seed) * 100),
    *aurocs,
)


In [None]:
TITLE_SIZE = 16
TICK_LABEL_SIZE = 14
ANNOT_SIZE = 18

# 2x2 grid with one heatmap per probe and a single colorbar on a separate axes.
fig, axes = plt.subplots(2, 2, figsize=(7, 10), gridspec_kw={'hspace': 0.02, 'wspace': 0.01})
cbar_ax = fig.add_axes([.93, 0.122, .02, .75])

# Styling shared by all four heatmaps.
heatmap_style = dict(
    annot=True,
    annot_kws={"fontsize":ANNOT_SIZE},
    vmin=0,
    vmax=100,
    square=True,
    cmap='coolwarm',
    fmt='d',
)

for i, probe in enumerate(probe_names):
    grid_row, grid_col = divmod(i, 2)
    ax = seaborn.heatmap(
        pd.DataFrame(aurocs_avg[probe]).T,
        ax=axes[grid_row][grid_col],
        cbar=(i==0),  # draw the colorbar only once
        cbar_ax=cbar_ax,
        **heatmap_style,
    )
    ax.set_title(probe.upper(), fontsize=TITLE_SIZE)
    if i != 2:
        # Every panel except the bottom-left one is drawn without tick labels.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.tick_params(left=False, bottom=False)
        continue
    # The bottom-left panel carries the tick labels for the whole grid.
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=TICK_LABEL_SIZE, rotation=20, ha='right')
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=TICK_LABEL_SIZE)
    ax.tick_params(left=True, bottom=True)
cbar_ax.tick_params(labelsize=TICK_LABEL_SIZE)
fig.tight_layout()
fig.subplots_adjust(wspace=0.01)

# figname = "figures_heatmap/neg_generalization.pdf"
# if os.path.exists(figname):
#     os.remove(figname)
# plt.savefig(figname, format='pdf', bbox_inches='tight')


In [None]:
TITLE_SIZE = 12
TICK_LABEL_SIZE = 12
ANNOT_SIZE = 10

# One row of four heatmaps (one per probe) sharing a single colorbar.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(13, 1.6), gridspec_kw={'wspace': 0.01})
cbar_ax = fig.add_axes([.91, 0.15, .01, .7])

for i, probe in enumerate(probe_names):
    ax = seaborn.heatmap(
        pd.DataFrame(aurocs_avg[probe]),
        annot=True,
        annot_kws={"fontsize":ANNOT_SIZE},
        vmin=0,
        vmax=100,
        # square=True,
        cmap='coolwarm',
        fmt='d',
        ax=axes[i],
        cbar=(i==0),  # draw the colorbar only once
        cbar_ax=cbar_ax,
    )
    ax.set_title(probe.upper(), fontsize=TITLE_SIZE)
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=TICK_LABEL_SIZE, ha='center')
    if i == 0:
        # Only the left-most panel shows the row labels.
        ax.set_yticklabels(
            ax.get_yticklabels(),
            fontsize=TICK_LABEL_SIZE,
        )
        ax.tick_params(left=True, bottom=True)
    else:
        ax.set_yticklabels([])
        ax.tick_params(left=False)
cbar_ax.tick_params(labelsize=TICK_LABEL_SIZE)
fig.tight_layout()
fig.subplots_adjust(wspace=0.01)

figname = "figures_heatmap/neg_generalization.pdf"
# Create the output directory first (as the variance figure cell does);
# savefig fails if it does not exist.
os.makedirs(os.path.dirname(figname), exist_ok=True)
if os.path.exists(figname):
    os.remove(figname)
plt.savefig(figname, format='pdf', bbox_inches='tight')


In [None]:
# Radar (polar) chart with one spoke per model and the hand-entered value of
# each model in generalization_num as the radius.
generalization_num = [0, 0, 4, 4, 4, 4, 5, 6]
models = [f"M{i}" for i in range(8)]
from math import pi

# Set data
df = pd.DataFrame({
'model': models,
'data': generalization_num,
})

N = len(models)

values = generalization_num

# Angle of each axis in the plot (the full circle divided by the number of models).
angles = [n / float(N) * 2 * pi for n in range(N)]

# Repeat the first point at the end so the outline drawn by ax.plot is closed.
# Separate lists are used so `angles` and `values` keep one entry per model
# (the tick positions below need exactly N angles).
closed_angles = angles + angles[:1]
closed_values = values + values[:1]

# Initialise the spider plot
ax = plt.subplot(111, polar=True)

# Draw one axis per model and label it
plt.xticks(angles, models)

# Draw the radial labels
ax.set_rlabel_position(0)
plt.yticks(range(6), range(6), color="grey")
plt.ylim(0,6)

# Plot data
ax.plot(closed_angles, closed_values, linewidth=1, linestyle='solid')

# Fill area
ax.fill(closed_angles, closed_values, 'b', alpha=0.1)


In [None]:
# Display the axis angles (in radians) computed for the radar chart.
angles

# Generalization to logical conjunctions/disjunctions

In [None]:
import joblib
import numpy as np
import os
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn

from probes import *
from utils import *


# Per model: index of the layer whose activations are fed to the probes.
model_names_layer_indices = {
    "Llama-2-7b-hf": 12,
    "Llama-2-7b-chat-hf": 13,
    "Llama-2-13b-hf": 13,
    "Llama-2-13b-chat-hf": 13,
    "Meta-Llama-3.1-8B-hf": 12,
    "Meta-Llama-3.1-8B-Instruct-hf": 13,
    "Meta-Llama-3.1-70B-hf": 33,
    "Meta-Llama-3.1-70B-Instruct-hf": 37,
    # "Mistral-Large-Instruct-2407": 43,
}

prompt_option = "no_prompt"

topics = (
    "animal_class",
    "cities",
    "inventors",
    "facts",
    "element_symb",
    "sp_en_trans",
)
probe_names = ('lr', 'mlp', 'svm', 'mm')
logical_transformation = "conj" # conj / disj

# aurocs[seed_idx][probe_name][model_name][topic] -> AUROC of the saved probe
# on the conjunction/disjunction version of the topic.
aurocs = []
for seed in (0,1,2):
    print(f"seed={seed}")
    aurocs_seed = {}
    for probe_name in probe_names:
        print(f"  probe={probe_name}")
        aurocs_seed[probe_name] = {}
        for model_name, layer_index in model_names_layer_indices.items():
            print(f"    model={model_name}")
            # Results are stored only under aurocs_seed[probe_name][model_name].
            # The extra top-level aurocs_seed[model_name] = {} entry was removed:
            # it mixed model names in among the probe keys of the result tree.
            activations_dir = f"activations_and_labels/{model_name}/{prompt_option}"
            probes_dir = f"probes/{model_name}/seed={seed}"
            probe = joblib.load(os.path.join(probes_dir, f"{probe_name}.joblib"))
            aurocs_seed[probe_name][model_name] = {}
            for topic in topics:
                print(f"      topic={topic}")
                test_topic = topic + "_" + logical_transformation
                # Disjunction datasets carry a "_new" suffix for every topic except "facts".
                if logical_transformation == "disj" and topic != "facts":
                    test_topic += "_new"
                test_activations = np.load(f"{activations_dir}/{test_topic}/acts_{layer_index}.npy")
                test_labels = pd.read_csv(f"data/{test_topic}.csv")['label']
                test_labels = np.array(test_labels.tolist())

                # Score with the probability the probe assigns to label 1.
                pos_label_probs = probe.predict_proba(test_activations)[:, 1]
                auroc = roc_auc_score(test_labels, pos_label_probs)
                aurocs_seed[probe_name][model_name][topic] = auroc
    aurocs.append(aurocs_seed)


In [None]:
import torch.utils._pytree as pytree_utils

def tree_average(trees, scale=1, round_to_int=False):
    tree_sum = trees[0]
    for tree in trees[1:]:
        tree_sum = pytree_utils.tree_map(lambda x, y: x + y, tree_sum, tree)
    avg_fn = lambda x: round(x/len(trees)*scale) if round_to_int else x/len(trees)*scale
    tree_avg = pytree_utils.tree_map(avg_fn, tree_sum)
    return tree_avg

# Average over the per-seed trees collected above; scale=100 with rounding
# turns each AUROC into an integer percentage (the heatmaps below use fmt='d').
aurocs_avg = tree_average(aurocs, scale=100, round_to_int=True)


In [None]:
# Re-key the model level with shorter display names: drop "-hf" and "Meta-",
# upper-case every "b", and capitalise "-chat".
aurocs_avg_new = {
    outer_key: {
        model.replace("-hf","").replace("Meta-","").replace('b','B').replace("-chat","-Chat"): per_topic
        for model, per_topic in per_model.items()
    }
    for outer_key, per_model in aurocs_avg.items()
}
aurocs_avg = aurocs_avg_new


In [None]:
# Slice the averaged results down to one model: {probe: {topic: AUROC}}.
target_model = 'Llama-3.1-8B'
aurocs_avg_single_model = {
    probe: {topic: aurocs_avg[probe][target_model][topic] for topic in topics}
    for probe in probe_names
}

TITLE_SIZE = 20
TICK_LABEL_SIZE = 20
ANNOT_SIZE = 20
fig, ax = plt.subplots(1,1,figsize=(5,5))
cbar_ax = fig.add_axes([.92, .25, .03, .5])
# Transposed so that probes are the rows and topics the columns.
single_model_table = pd.DataFrame(aurocs_avg_single_model).T
ax = seaborn.heatmap(
    single_model_table,
    annot=True,
    annot_kws={"fontsize":ANNOT_SIZE},
    vmin=0,
    vmax=100,
    square=True,
    cmap='coolwarm',
    fmt='d',
    ax=ax,
    cbar_ax=cbar_ax,
)
ax.set_xticklabels(ax.get_xticklabels(), fontsize=TICK_LABEL_SIZE, rotation=45, ha='right')
# Probe names are shown upper-cased on the y axis.
upper_probe_labels = [label.get_text().upper() for label in ax.get_yticklabels()]
ax.set_yticklabels(upper_probe_labels, fontsize=TICK_LABEL_SIZE, rotation=0)
cbar_ax.tick_params(labelsize=TICK_LABEL_SIZE)

# figname = f"figures_heatmap/{logical_transformation}_generalization.pdf"
# if os.path.exists(figname):
#     os.remove(figname)
# plt.savefig(figname, format='pdf', bbox_inches='tight')


In [None]:
TITLE_SIZE = 20
TICK_LABEL_SIZE = 20
ANNOT_SIZE = 20

# One heatmap per probe (rows: models, columns: topics) with a shared y axis
# and a single colorbar.
fig, axes = plt.subplots(1, len(probe_names), figsize=(19, 7), sharey=True)
cbar_ax = fig.add_axes([1., .35, .01, .5])

for i, probe in enumerate(probe_names):
    ax = seaborn.heatmap(
        pd.DataFrame(aurocs_avg[probe]).T,
        annot=True,
        annot_kws={"fontsize":ANNOT_SIZE},
        vmin=0,
        vmax=100,
        square=True,
        cmap='coolwarm',
        fmt='d',
        ax=axes[i],
        cbar=(i==0),  # draw the colorbar only once
        cbar_ax=cbar_ax,
    )
    ax.set_title(probe.upper(), fontsize=TITLE_SIZE)
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=TICK_LABEL_SIZE, rotation=45, ha='right')
    if i == 0:
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=TICK_LABEL_SIZE)
    else:
        ax.tick_params(left=False)
cbar_ax.tick_params(labelsize=TICK_LABEL_SIZE)
fig.tight_layout()
fig.subplots_adjust(wspace=0.01)

figname = f"figures_heatmap/{logical_transformation}_generalization_full.pdf"
# Create the output directory first (as the variance figure cell does);
# savefig fails if it does not exist.
os.makedirs(os.path.dirname(figname), exist_ok=True)
if os.path.exists(figname):
    os.remove(figname)
plt.savefig(figname, format='pdf', bbox_inches='tight')


# Generalization to QA (MMLU)

In [None]:
from itertools import chain
import joblib
import numpy as np
import os
from sklearn.calibration import calibration_curve
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve, brier_score_loss
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn

from probes import *
from utils import *


# Per model: index of the layer whose activations are fed to the probes.
model_names_layer_indices = {
    # "Llama-2-7b-hf": 12,
    # "Llama-2-7b-chat-hf": 13,
    # "Llama-2-13b-hf": 13,
    # "Llama-2-13b-chat-hf": 13,
    "Meta-Llama-3.1-8B-hf": 12,
    # "Meta-Llama-3.1-8B-Instruct-hf": 13,
    # "Meta-Llama-3.1-70B-hf": 33,
    # "Meta-Llama-3.1-70B-Instruct-hf": 37,
    # "Mistral-Large-Instruct-2407": 43,
}

prompt_option = "no_prompt"

# topics = (
#     "mmlu_true_false",
#     "mmlu_true_false_with_options",
#     "mmlu_true_false_with_options_TTTTT",
#     "mmlu_true_false_with_options_TTFFF",
# )
# settings = ['no opt.', 'w/ opt.', 'TTTTT', 'TTFFF']

# Evaluated datasets and, in the same order, their legend labels.
topics = (
    "mmlu_true_false_mc",
    "mmlu_true_false_mc_TTTTT",
    "mmlu_true_false_mc_TTFFF",
)
settings = ['0-shot', 'TTTTT', 'TTFFF']

probe_names = ('lr', 'mlp', 'svm', 'mm')

# results[metric] is a list with one entry per seed:
# {probe_name: {model_name: {topic: value}}}.
metrics = ['auroc', 'brier', 'ece']
results = {metric: [] for metric in metrics}

for seed in (0,1,2):
    print(f"seed={seed}")
    for per_seed in results.values():
        per_seed.append({})

    for probe_name in probe_names:
        print(f"  probe={probe_name}")
        for per_seed in results.values():
            per_seed[-1][probe_name] = {}

        for model_name, layer_index in model_names_layer_indices.items():
            print(f"    model={model_name}")
            for per_seed in results.values():
                per_seed[-1][probe_name][model_name] = {}

            activations_dir = f"activations_and_labels/{model_name}/{prompt_option}"
            probes_dir = f"probes/{model_name}/seed={seed}"
            probe = joblib.load(os.path.join(probes_dir, f"{probe_name}.joblib"))
            for topic in topics:
                print(f"      topic={topic}")

                test_topic = topic
                test_activations = np.load(f"{activations_dir}/{test_topic}/acts_{layer_index}.npy")
                test_labels = np.array(pd.read_csv(f"data/{test_topic}.csv")['label'].tolist())

                # Balance the classes: keep every positive example and draw the
                # same number of negatives with a fixed-seed generator.
                # NOTE(review): choice() samples with replacement by default, so
                # a negative can be drawn more than once - confirm this is intended.
                positive_ids = np.flatnonzero(test_labels == 1)
                negative_ids = np.flatnonzero(test_labels == 0)
                rng = np.random.RandomState(42)
                sampled_negative_ids = rng.choice(negative_ids, len(positive_ids))
                kept_ids = np.concatenate([positive_ids, sampled_negative_ids])
                test_activations = test_activations[kept_ids]
                test_labels = test_labels[kept_ids]

                # Probability the probe assigns to label 1.
                pos_label_probs = probe.predict_proba(test_activations)[:, 1]
                slot = (probe_name, model_name, topic)

                auroc = roc_auc_score(test_labels, pos_label_probs)
                results['auroc'][-1][slot[0]][slot[1]][slot[2]] = auroc

                brier = brier_score_loss(test_labels, pos_label_probs)
                results['brier'][-1][slot[0]][slot[1]][slot[2]] = brier

                # ECE is computed by the project helper from a 10-bin quantile
                # calibration curve.
                prob_true, prob_pred = calibration_curve(test_labels, pos_label_probs, n_bins=10, strategy='quantile')
                ece = calibration_error_expectation(prob_true, prob_pred)
                results['ece'][-1][slot[0]][slot[1]][slot[2]] = ece


In [None]:
target_model = 'Meta-Llama-3.1-8B-hf'

# Flatten the nested results into one row per (seed entry, probe, topic) with
# one column per metric, for the target model only.
metric_rows = [
    {
        "probe": probe_name,
        "topic": topic,
        **{m: results[m][index][probe_name][target_model][topic] for m in metrics},
    }
    for index in range(len(results[metrics[0]]))
    for probe_name in probe_names
    for topic in topics
]
data_df = pd.DataFrame(metric_rows, columns=["probe", "topic", "auroc", "brier", "ece"])

## Open-ended answer

In [None]:
seaborn.set_theme('paper', 'white')

# Three stacked panels sharing the probe axis: AUROC, ECE and Brier score.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=(9, 4), sharex=True, gridspec_kw={'hspace': 0.2}, height_ratios=(0.34, 0.33, 0.33))
all_axes = (ax1, ax2, ax3)

# PALETTE = ["#e8ddcb","#cdb380","#036564","#033649","#031634"]
PALETTE = "coolwarm"

# Only the top panel keeps its legend.
panels = (
    (ax1, 'auroc', {}),
    (ax2, 'ece', {'legend': False}),
    (ax3, 'brier', {'legend': False}),
)
for panel_ax, metric, legend_kwargs in panels:
    seaborn.barplot(data_df, x='probe', y=metric, hue='topic', order=('lr', 'mlp', 'mm', 'svm'), ax=panel_ax, errorbar=('se'), capsize=0.2, palette=PALETTE, **legend_kwargs)

# ax1.axhline(0.5, ls='--', color='darkgray')
ax1.set_ylim(0.5)

ax2.set_ylim(0, 0.5)

# Replace the dataset names in the legend with the short setting labels.
for legend_text, setting in zip(ax1.legend_.texts, settings):
    legend_text.set_text(setting)
ax1.legend_.set_title('')
seaborn.move_legend(ax1, 'upper left', bbox_to_anchor=(0.12, 1.5), ncol=len(settings))

for panel_ax, ylabel in zip(all_axes, (r"AUROC$\uparrow$", r"ECE$\downarrow$", r"BS$\downarrow$")):
    panel_ax.set_ylabel(ylabel)
ax3.set_xlabel("Probe")
ax3.set_xticklabels([tick.get_text().upper() for tick in ax3.get_xticklabels()])

for panel_ax in all_axes:
    panel_ax.tick_params(left=True)

ax1.set_ylim(0.5, 0.8)
auroc_ticks = np.linspace(0.5, 0.8, 4)
ax1.set_yticks(auroc_ticks, np.linspace(0.5, 0.8, 4, dtype=np.float16))
ax2.set_yticks(np.linspace(0, 0.4, 3), np.linspace(0, 0.4, 3, dtype=np.float16))
ax3.set_ylim(0, 0.47)
ax3.set_yticks(np.linspace(0, 0.4, 3), np.linspace(0, 0.4, 3, dtype=np.float16))
# Dashed reference line at a Brier score of 0.25.
ax3.axhline(0.25, ls='--', color='darkgray')

FONTSIZE = 20
plt.setp(ax1.get_legend().get_texts(), fontsize=FONTSIZE-4)

for panel_ax in all_axes:
    panel_ax.tick_params(labelsize=FONTSIZE-6)

ax3.set_xlabel(ax3.get_xlabel(), fontsize=FONTSIZE-6)

for panel_ax in all_axes:
    panel_ax.set_ylabel(panel_ax.get_ylabel(), fontsize=FONTSIZE-6)

# Print the value on top of every bar.
for panel_ax in all_axes:
    for cont in panel_ax.containers:
        panel_ax.bar_label(cont, fmt='%.2f', fontsize=10)

plt.tight_layout()


## Multiple choice

In [None]:
seaborn.set_theme('paper', 'white')

# Three stacked panels sharing the probe axis: AUROC, ECE and Brier score.
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=(9, 4), sharex=True, gridspec_kw={'hspace': 0.2}, height_ratios=(0.34, 0.33, 0.33))

# PALETTE = ["#e8ddcb","#cdb380","#036564","#033649","#031634"]
PALETTE = "coolwarm"
# Skip the first two colours of the palette.
palette = seaborn.color_palette(PALETTE)[2:]
seaborn.barplot(data_df, x='probe', y='auroc', hue='topic', order=('lr', 'mlp', 'mm', 'svm'), ax=ax1, errorbar=('se'), capsize=0.2, palette=palette)
seaborn.barplot(data_df, x='probe', y='ece',   hue='topic', order=('lr', 'mlp', 'mm', 'svm'), ax=ax2, errorbar=('se'), capsize=0.2, palette=palette, legend=False)
seaborn.barplot(data_df, x='probe', y='brier', hue='topic', order=('lr', 'mlp', 'mm', 'svm'), ax=ax3, errorbar=('se'), capsize=0.2, palette=palette, legend=False)

# ax1.axhline(0.5, ls='--', color='darkgray')
ax1.set_ylim(0.5)

ax2.set_ylim(0, 0.5)

# Replace the dataset names in the legend with the short setting labels.
for t, l in zip(ax1.legend_.texts, settings):
    t.set_text(l)
ax1.legend_.set_title('')
seaborn.move_legend(ax1, 'upper left', bbox_to_anchor=(0.23, 1.5), ncol=len(settings))

ax1.set_ylabel(r"AUROC$\uparrow$")
ax2.set_ylabel(r"ECE$\downarrow$")
ax3.set_ylabel(r"BS$\downarrow$")
ax3.set_xlabel("Probe")
ax3.set_xticklabels([x.get_text().upper() for x in ax3.get_xticklabels()])

ax1.tick_params(left=True)
ax2.tick_params(left=True)
ax3.tick_params(left=True)

ax1.set_ylim(0.5, 0.8)
ax1.set_yticks(np.linspace(0.5, 0.8, 4), np.linspace(0.5, 0.8, 4, dtype=np.float16))
ax2.set_ylim(0, 0.55)
ax2.set_yticks(np.linspace(0, 0.4, 3), np.linspace(0, 0.4, 3, dtype=np.float16))
ax3.set_ylim(0, 0.55)
ax3.set_yticks(np.linspace(0, 0.4, 3), np.linspace(0, 0.4, 3, dtype=np.float16))
# Dashed reference line at a Brier score of 0.25.
ax3.axhline(0.25, ls='--', color='darkgray')

FONTSIZE = 20
plt.setp(ax1.get_legend().get_texts(), fontsize=FONTSIZE-4)

ax1.tick_params(labelsize=FONTSIZE-6)
ax2.tick_params(labelsize=FONTSIZE-6)
ax3.tick_params(labelsize=FONTSIZE-6)

ax3.set_xlabel(ax3.get_xlabel(), fontsize=FONTSIZE-6)

ax1.set_ylabel(ax1.get_ylabel(), fontsize=FONTSIZE-6)
ax2.set_ylabel(ax2.get_ylabel(), fontsize=FONTSIZE-6)
ax3.set_ylabel(ax3.get_ylabel(), fontsize=FONTSIZE-6)

# Print the value on top of every bar.
for cont in ax1.containers:
    ax1.bar_label(cont, fmt='%.2f', fontsize=10)
for cont in ax2.containers:
    ax2.bar_label(cont, fmt='%.2f', fontsize=10)
for cont in ax3.containers:
    ax3.bar_label(cont, fmt='%.2f', fontsize=10)

plt.tight_layout()

figname = "figures_qa/mmlu_mc.pdf"
# Create the output directory first (as the variance figure cell does);
# savefig fails if it does not exist.
os.makedirs(os.path.dirname(figname), exist_ok=True)
plt.savefig(figname, format='pdf', bbox_inches='tight')


# Generalization to QA (In-context)

In [None]:
import joblib
import numpy as np
import os
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_auc_score, roc_curve, brier_score_loss
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn
from itertools import chain

from probes import *
from utils import *


# Evaluate the saved probes on the contextual-QA datasets.
#
# Maps each model to the index of the layer whose saved activations
# (acts_<index>.npy) are evaluated. Commented-out entries are kept as a record
# of the indices used for the other models.
model_names_layer_indices = {
    # "Llama-2-7b-hf": 12,
    # "Llama-2-7b-chat-hf": 13,
    # "Llama-2-13b-hf": 13,
    # "Llama-2-13b-chat-hf": 13,
    "Meta-Llama-3.1-8B-hf": 12,
    # "Meta-Llama-3.1-8B-Instruct-hf": 13,
    # "Meta-Llama-3.1-70B-hf": 33,
    # "Meta-Llama-3.1-70B-Instruct-hf": 18,
    # "Mistral-Large-Instruct-2407": 43,
}

prompt_option = "no_prompt"

# Dataset names, one inner tuple per task. Each name selects both
# data/<name>.csv and the activation directory of the same name.
topics = (
    (
        "sciq_true_false_mc",
        "sciq_true_false_mc_TTT",
        "sciq_true_false_mc_TTF",
        "sciq_true_false_mc_FFT",
    ),
    (
        "boolq_true_false",
        "boolq_true_false_with_options",
        "boolq_true_false_with_options_T",
        "boolq_true_false_with_options_F",
    ),
    (
        "xsum_true_false",
        "xsum_true_false_T",
        "xsum_true_false_TT",
        "xsum_true_false_TTT",
    ),
)
# Legend labels used by the plotting cells, laid out in the same order as
# `topics`.
settings = (
    ('0-shot', 'TTT', 'TTF', 'FFT'),
    ('no opt.', 'w/ opt.', 'T', 'F'),
    ('0-shot', 'T', 'TT', 'TTT'),
)

probe_names = ('lr', 'mlp', 'svm', 'mm')
metrics = ['auroc', 'brier', 'ece']


def _load_balanced_test_set(activations_dir, topic, layer_index):
    """Load one test set and balance its two classes.

    Returns ``(activations, labels)`` restricted to every row labelled 1 plus
    an equally sized random draw of rows labelled 0. The draw uses a fixed
    seed (42), so it is identical on every call.
    """
    activations = np.load(f"{activations_dir}/{topic}/acts_{layer_index}.npy")
    labels = np.array(pd.read_csv(f"data/{topic}.csv")['label'].tolist())

    pos_indices = np.where(labels == 1)[0].tolist()
    neg_indices = np.where(labels == 0)[0].tolist()
    rng = np.random.RandomState(42)
    # NOTE(review): choice() draws with replacement by default, so the same
    # negative row can be picked more than once. Confirm this is intended
    # before switching to replace=False (which fails when negatives are the
    # minority class).
    neg_chosen = rng.choice(neg_indices, len(pos_indices))
    chosen = np.concatenate([pos_indices, neg_chosen])
    return activations[chosen], labels[chosen]


# A test set depends only on (model, topic), never on the seed or the probe,
# so read and balance each one once instead of once per (seed, probe) pair.
test_sets = {}
for model_name, layer_index in model_names_layer_indices.items():
    activations_dir = f"activations_and_labels/{model_name}/{prompt_option}"
    for topic in chain(*topics):
        test_sets[model_name, topic] = _load_balanced_test_set(
            activations_dir, topic, layer_index
        )

# results[metric] is a list with one entry per seed; each entry is a nested
# dict indexed as [probe_name][model_name][topic].
results = {m: [] for m in metrics}

for seed in (0, 1, 2):
    print(f"seed={seed}")
    for m in metrics:
        results[m].append({})

    for probe_name in probe_names:
        print(f"  probe={probe_name}")
        for m in metrics:
            results[m][-1][probe_name] = {}

        for model_name in model_names_layer_indices:
            print(f"    model={model_name}")
            for m in metrics:
                results[m][-1][probe_name][model_name] = {}

            probes_dir = f"probes/{model_name}/seed={seed}"
            # joblib.load unpickles the file: only load probe files you trust.
            probe = joblib.load(os.path.join(probes_dir, f"{probe_name}.joblib"))
            for topic in chain(*topics):
                print(f"      topic={topic}")
                test_activations, test_labels = test_sets[model_name, topic]

                # Column 1 of predict_proba is taken as P(label == 1).
                pos_label_probs = probe.predict_proba(test_activations)[:, 1]
                scores = {
                    'auroc': roc_auc_score(test_labels, pos_label_probs),
                    'brier': brier_score_loss(test_labels, pos_label_probs),
                }
                prob_true, prob_pred = calibration_curve(
                    test_labels, pos_label_probs, n_bins=10, strategy='quantile'
                )
                # calibration_error_expectation comes from a star import
                # (probes/utils); presumably it reduces the per-bin curve to an
                # ECE value -- confirm against its definition.
                scores['ece'] = calibration_error_expectation(prob_true, prob_pred)

                for m in metrics:
                    results[m][-1][probe_name][model_name][topic] = scores[m]


In [None]:
# Flatten `results` into a long-format table for seaborn: one row per
# (seed, probe, topic) holding the three metric values of `target_model`.
target_model = 'Meta-Llama-3.1-8B-hf'
# target_model = 'Meta-Llama-3.1-70B-Instruct-hf'

num_runs = len(results[metrics[0]])  # one entry per seed
records = [
    {
        "probe": probe_name,
        "topic": topic,
        **{m: results[m][run][probe_name][target_model][topic] for m in metrics},
    }
    for run in range(num_runs)
    for probe_name in probe_names
    for topic in chain(*topics)
]
data_df = pd.DataFrame(records, columns=["probe", "topic", "auroc", "brier", "ece"])

In [None]:
# deprecated
# Earlier two-row version of the contextual-QA figure: AUROC in the top row,
# ECE in the bottom row, one column per task. The savefig call at the end of
# this cell is commented out, so running it only draws the figure.

seaborn.set_theme('paper', 'whitegrid')

tasks = ("SciQ", "BoolQ", "XSum")
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(13, 3), sharey='row', gridspec_kw={'hspace': 0, 'wspace': 0.01}, height_ratios=(0.6, 0.4))

# PALETTE = ["#e8ddcb","#cdb380","#036564","#033649","#031634"]
PALETTE = 'coolwarm'
# Column i plots the rows of data_df whose topic belongs to topics[i].
for i in range(len(topics)):
    topic_df = data_df.loc[data_df['topic'].isin(topics[i])]
    if i == 2:
        # Third column uses a hand-picked subset of the palette:
        # its first colour followed by its last three.
        palette = seaborn.color_palette(PALETTE)
        palette = palette[0:1] + palette[-3:]
    else:
        palette = PALETTE
    # Only the bottom-row plot keeps its legend (the top row passes legend=False).
    seaborn.barplot(topic_df, x='probe', y='auroc', hue='topic', order=('lr', 'mlp', 'mm', 'svm'), ax=axes[0][i], errorbar=('se'), capsize=0.2, palette=palette, legend=False)
    seaborn.barplot(topic_df, x='probe', y='ece', hue='topic', order=('lr', 'mlp', 'mm', 'svm'), ax=axes[1][i], errorbar=('se'), capsize=0.2, palette=palette)

axes[0][0].tick_params(left=True)
axes[1][0].tick_params(left=True)

# Replace the legend entries (dataset names) with the short labels from
# `settings`, then move each legend below its axes.
for ax, setting in zip(axes[1], settings):
    for t, l in zip(ax.legend_.texts, setting):
        t.set_text(l)
    ax.legend_.set_title('')
    ax.legend_.set_ncols(len(setting))
    seaborn.move_legend(ax, 'upper left', bbox_to_anchor=(0, -0.5))

# Invert the ECE axis so that 0 sits at the top of the bottom row.
for ax in axes[1]:
    ax.set_ylim(0, 0.44)
    ax.invert_yaxis()

FONTSIZE = 16
axes[0][0].set_ylabel(r"AUROC$\uparrow$", fontsize=FONTSIZE-2)
axes[1][0].set_ylabel(r"ECE$\downarrow$", fontsize=FONTSIZE-2)

axes[0][0].tick_params(labelsize=FONTSIZE-2)
axes[1][0].tick_params(labelsize=FONTSIZE-2)

for ax, task in zip(axes[0], tasks):
    # Second argument of set_yticks supplies the tick labels; presumably
    # float16 is used only so the labels print as short decimals.
    ax.set_yticks(np.linspace(0,1,6), np.linspace(0,1,6, dtype=np.float16))
    ax.set_title(task, fontsize=FONTSIZE)
    # Dashed reference line at AUROC = 0.5.
    ax.axhline(0.5, ls='--', color='darkgray')

for ax in axes[1]:
    plt.setp(ax.get_legend().get_texts(), fontsize=FONTSIZE-2)
    ax.tick_params(labelsize=FONTSIZE-2)
    ax.set_xlabel("Probe", fontsize=FONTSIZE-2)
    # Upper-case the probe names on the x axis.
    ax.set_xticklabels([x.get_text().upper() for x in ax.get_xticklabels()])
    ax.set_yticks(np.linspace(0.1, 0.4, 4), np.linspace(0.1, 0.4, 4, dtype=np.float16))
# ax2.set_xlabel(ax2.get_xlabel(), fontsize=FONTSIZE)
# ax1.set_ylabel(ax1.get_ylabel(), fontsize=FONTSIZE)
# ax2.set_ylabel(ax2.get_ylabel(), fontsize=FONTSIZE)

# for axes_row in axes:
#     for ax in axes_row:
#         for cont in ax.containers:
#             ax.bar_label(cont, fmt='%.2f', fontsize=5)

plt.tight_layout()

# plt.savefig("figures_qa/contextual_qa.pdf", format='pdf', bbox_inches='tight')


In [None]:
# Contextual-QA figure: a 3x3 grid with one row per metric (AUROC, ECE, Brier
# score) and one column per task. Bars are grouped by probe and coloured by
# dataset variant. The result is written to figures_qa/contextual_qa.pdf.
seaborn.set_theme('paper', 'whitegrid')

tasks = ("SciQ", "BoolQ", "XSum")
# Short legend labels, one tuple per column, in the same order as `topics`.
settings = (
    ('0-shot', 'TTT', 'TTF', 'FFT'),
    ('no opt.', 'w/ opt.', 'T', 'F'),
    ('0-shot', 'T', 'TT', 'TTT'),
)
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(13, 3), sharex=True, sharey='row', gridspec_kw={'hspace': 0.2, 'wspace': 0.01}, height_ratios=(0.34, 0.33, 0.33))

PALETTE = 'coolwarm'
PROBE_ORDER = ('lr', 'mlp', 'mm', 'svm')
# Hand-tuned horizontal legend anchor for each column.
LEGEND_X_ANCHORS = (0.045, 0.05, 0.07)

base_colors = seaborn.color_palette(PALETTE)
for i, topic_group in enumerate(topics):
    topic_df = data_df.loc[data_df['topic'].isin(topic_group)]
    if i == 2:
        # Third column: second palette colour followed by the last three.
        palette = base_colors[1:2] + base_colors[-3:]
    else:
        palette = base_colors[1:]
    bar_kwargs = dict(x='probe', hue='topic', order=PROBE_ORDER, errorbar='se', capsize=0.2, palette=palette)
    seaborn.barplot(topic_df, y='auroc', ax=axes[0][i], legend=False, **bar_kwargs)
    seaborn.barplot(topic_df, y='ece', ax=axes[1][i], legend=False, **bar_kwargs)
    # Only the bottom row keeps its legend; it is relabelled and moved below.
    seaborn.barplot(topic_df, y='brier', ax=axes[2][i], **bar_kwargs)

for axes_row in axes:
    axes_row[0].tick_params(left=True)

# Swap the dataset names in each legend for the short setting labels and place
# the legend under its own column. The column count comes from that column's
# own label tuple rather than from a variable left over from an earlier loop.
for ax, labels, x_anchor in zip(axes[2], settings, LEGEND_X_ANCHORS):
    for text, label in zip(ax.legend_.texts, labels):
        text.set_text(label)
    ax.legend_.set_title('')
    seaborn.move_legend(ax, 'upper left', bbox_to_anchor=(x_anchor, -0.6), ncol=len(labels), columnspacing=0.1)

for ax in axes[1]:
    ax.set_ylim(0, 0.44)

FONTSIZE = 16
axes[0][0].set_ylabel(r"AUROC$\uparrow$", fontsize=FONTSIZE-4)
axes[1][0].set_ylabel(r"ECE$\downarrow$", fontsize=FONTSIZE-4)
axes[2][0].set_ylabel(r"BS$\downarrow$", fontsize=FONTSIZE-4)

axes[0][0].tick_params(labelsize=FONTSIZE-6)

for ax, task in zip(axes[0], tasks):
    ax.set_ylim(0.5, 1.0)
    # Second argument of set_yticks supplies the tick labels; presumably
    # float16 is used only so the labels print as short decimals.
    ax.set_yticks(np.linspace(0.5,1,6), np.linspace(0.5,1,6, dtype=np.float16))
    ax.set_title(task, fontsize=FONTSIZE-4)
    # ax.axhline(0.5, ls='--', color='darkgray')

for ax in axes[1]:
    ax.tick_params(labelsize=FONTSIZE-6, top=False, bottom=False, which='both')
    ax.set_yticks(np.linspace(0., 0.4, 5), np.linspace(0., 0.4, 5, dtype=np.float16))

for ax in axes[2]:
    plt.setp(ax.get_legend().get_texts(), fontsize=FONTSIZE-4)
    ax.tick_params(labelsize=FONTSIZE-6, top=False, bottom=False, which='both')
    ax.set_xlabel("Probe", fontsize=FONTSIZE-6)
    # Upper-case the probe names on the x axis.
    ax.set_xticklabels([x.get_text().upper() for x in ax.get_xticklabels()])
    ax.set_yticks(np.linspace(0., 0.4, 5), np.linspace(0., 0.4, 5, dtype=np.float16))
    # Dashed reference at 0.25, the Brier score of a constant 0.5 prediction.
    ax.axhline(0.25, ls='--', color='darkgray')

# for axes_row in axes:
#     for ax in axes_row:
#         for cont in ax.containers:
#             ax.bar_label(cont, fmt='%.2f', fontsize=5)

plt.tight_layout()

plt.savefig("figures_qa/contextual_qa.pdf", format='pdf', bbox_inches='tight')


In [None]:
# Styles a two-panel figure (ax1 = AUROC, ax2 = ECE) and saves it as
# figures_qa/xsum.pdf.
# NOTE(review): this cell creates neither a figure nor ax1/ax2; it relies on
# axes left over from whichever cell was run before it. It also iterates
# `settings` directly as legend labels, which only works if `settings` is a
# flat sequence of strings at that point -- confirm which figure this cell is
# meant to finish, or whether it is a stale leftover.
ax1.axhline(0.5, ls='--', color='darkgray')
# ax2.axhline(0.25, ls='--', color='darkgray')

# Put 0 at the top of the ECE panel.
ax2.invert_yaxis()

# Relabel the top legend with the entries of `settings`, drop its title, and
# keep a legend on the top panel only.
for t, l in zip(ax1.legend_.texts, settings):
    t.set_text(l)
ax1.legend_.set_title('')
seaborn.move_legend(ax1, 'lower right')
ax2.legend_.remove()

ax1.set_ylabel(r"AUROC$\uparrow$")
ax2.set_ylabel(r"ECE$\downarrow$")
ax2.set_xlabel("Probe")
# Upper-case the probe names on the x axis.
ax2.set_xticklabels([x.get_text().upper() for x in ax2.get_xticklabels()])
ax1.tick_params(left=True)
ax2.tick_params(left=True)

# ax1.set_ylim(0, 1.1)


# Second argument supplies the tick labels; presumably float16 is used only so
# the labels print as short decimals.
ax2.set_yticks(np.linspace(0.1, 0.3, 3), np.linspace(0.1, 0.3, 3, dtype=np.float16))

# ax2.set_yticks((0.1, 0.2, 0.3, 0.4), (0.1, 0.2, 0.3, 0.4)) # sciq

FONTSIZE = 20
plt.setp(ax1.get_legend().get_texts(), fontsize=FONTSIZE-2)
ax2.tick_params(labelsize=FONTSIZE-2)
ax1.tick_params(labelsize=FONTSIZE-2)
# Re-apply the existing label text just to change its font size.
ax2.set_xlabel(ax2.get_xlabel(), fontsize=FONTSIZE)
ax1.set_ylabel(ax1.get_ylabel(), fontsize=FONTSIZE)
ax2.set_ylabel(ax2.get_ylabel(), fontsize=FONTSIZE)

# Print each bar's value (two decimals) next to the bar.
for cont in ax1.containers:
    ax1.bar_label(cont, fmt='%.2f', fontsize=FONTSIZE-8)
for cont in ax2.containers:
    ax2.bar_label(cont, fmt='%.2f', fontsize=FONTSIZE-8)

plt.tight_layout()

plt.savefig("figures_qa/xsum.pdf", format='pdf')


# TriviaQA

In [None]:
from itertools import chain
import joblib
import numpy as np
import os
from sklearn.calibration import calibration_curve
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve, brier_score_loss
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn

from probes import *
from utils import *


# Evaluate the saved probes on the TriviaQA true/false datasets.
#
# Maps each model to the index of the layer whose saved activations
# (acts_<index>.npy) are evaluated. Commented-out entries are kept as a record
# of the indices used for the other models.
model_names_layer_indices = {
    # "Llama-2-7b-hf": 12,
    # "Llama-2-7b-chat-hf": 13,
    # "Llama-2-13b-hf": 13,
    # "Llama-2-13b-chat-hf": 13,
    "Meta-Llama-3.1-8B-hf": 12,
    # "Meta-Llama-3.1-8B-Instruct-hf": 13,
    # "Meta-Llama-3.1-70B-hf": 33,
    # "Meta-Llama-3.1-70B-Instruct-hf": 37,
    # "Mistral-Large-Instruct-2407": 43,
}

prompt_option = "no_prompt"

# Dataset names; each selects both data/<name>.csv and the activation
# directory of the same name.
topics = (
    "triviaqa_true_false_Meta-Llama-3.1-8B-hf-shots=5",
    "triviaqa_true_false_Meta-Llama-3.1-8B-hf-shots=20",
)
# Legend labels used by the plotting cell, in the same order as `topics`.
settings = ['5-shot', '20-shot']

probe_names = ('lr', 'mlp', 'svm', 'mm')
metrics = ['auroc', 'brier', 'ece']


def _load_test_set(activations_dir, topic, layer_index):
    """Return ``(activations, labels)`` for one dataset, used as stored.

    No class balancing is applied here: every row of the dataset is kept.
    """
    activations = np.load(f"{activations_dir}/{topic}/acts_{layer_index}.npy")
    labels = np.array(pd.read_csv(f"data/{topic}.csv")['label'].tolist())
    return activations, labels


# A test set depends only on (model, topic), never on the seed or the probe,
# so read each one once instead of once per (seed, probe) pair.
test_sets = {}
for model_name, layer_index in model_names_layer_indices.items():
    activations_dir = f"activations_and_labels/{model_name}/{prompt_option}"
    for topic in topics:
        test_sets[model_name, topic] = _load_test_set(activations_dir, topic, layer_index)

# results[metric] is a list with one entry per seed; each entry is a nested
# dict indexed as [probe_name][model_name][topic].
results = {m: [] for m in metrics}

for seed in (0, 1, 2):
    print(f"seed={seed}")
    for m in metrics:
        results[m].append({})

    for probe_name in probe_names:
        print(f"  probe={probe_name}")
        for m in metrics:
            results[m][-1][probe_name] = {}

        for model_name in model_names_layer_indices:
            print(f"    model={model_name}")
            for m in metrics:
                results[m][-1][probe_name][model_name] = {}

            probes_dir = f"probes/{model_name}/seed={seed}"
            # joblib.load unpickles the file: only load probe files you trust.
            probe = joblib.load(os.path.join(probes_dir, f"{probe_name}.joblib"))
            for topic in topics:
                print(f"      topic={topic}")
                test_activations, test_labels = test_sets[model_name, topic]

                # Column 1 of predict_proba is taken as P(label == 1).
                pos_label_probs = probe.predict_proba(test_activations)[:, 1]
                scores = {
                    'auroc': roc_auc_score(test_labels, pos_label_probs),
                    'brier': brier_score_loss(test_labels, pos_label_probs),
                }
                prob_true, prob_pred = calibration_curve(
                    test_labels, pos_label_probs, n_bins=10, strategy='quantile'
                )
                # calibration_error_expectation comes from a star import
                # (probes/utils); presumably it reduces the per-bin curve to an
                # ECE value -- confirm against its definition.
                scores['ece'] = calibration_error_expectation(prob_true, prob_pred)

                for m in metrics:
                    results[m][-1][probe_name][model_name][topic] = scores[m]


In [None]:
# Flatten `results` into a long-format table for seaborn: one row per
# (seed, probe, topic) holding the three metric values of `target_model`.
target_model = 'Meta-Llama-3.1-8B-hf'

table_rows = []
for run in range(len(results[metrics[0]])):  # one entry per seed
    for probe_name in probe_names:
        for topic in topics:
            row = {"probe": probe_name, "topic": topic}
            for m in metrics:
                row[m] = results[m][run][probe_name][target_model][topic]
            table_rows.append(row)
data_df = pd.DataFrame(table_rows, columns=["probe", "topic", "auroc", "brier", "ece"])

In [None]:
# TriviaQA figure: three stacked panels (AUROC, ECE, Brier score) sharing the
# probe axis, with bars coloured by dataset variant. The result is written to
# figures_qa/triviaqa.pdf.
seaborn.set_theme('paper', 'white')

fig, (ax1, ax2, ax3) = plt.subplots(
    nrows=3, ncols=1, figsize=(9, 4), sharex=True,
    gridspec_kw={'hspace': 0.2}, height_ratios=(0.34, 0.33, 0.33),
)

# PALETTE = ["#e8ddcb","#cdb380","#036564","#033649","#031634"]
PALETTE = "coolwarm"
FONTSIZE = 20

# Arguments shared by the three bar plots; only the metric and axes change.
bar_style = dict(
    x='probe', hue='topic', order=('lr', 'mlp', 'mm', 'svm'),
    errorbar='se', capsize=0.2, palette=PALETTE,
)
# The top panel keeps its legend; the two lower panels suppress theirs.
seaborn.barplot(data_df, y='auroc', ax=ax1, **bar_style)
for lower_ax, metric in ((ax2, 'ece'), (ax3, 'brier')):
    seaborn.barplot(data_df, y=metric, ax=lower_ax, legend=False, **bar_style)

# ax1.axhline(0.5, ls='--', color='darkgray')
ax1.set_ylim(0.5)

ax2.set_ylim(0, 0.5)

# Replace the dataset names in the legend with the short labels in `settings`
# and move the legend above the top panel.
top_legend = ax1.legend_
for entry, label in zip(top_legend.texts, settings):
    entry.set_text(label)
top_legend.set_title('')
seaborn.move_legend(ax1, 'upper left', bbox_to_anchor=(0.31, 1.5), ncol=len(settings))

axis_titles = (
    (ax1, r"AUROC$\uparrow$"),
    (ax2, r"ECE$\downarrow$"),
    (ax3, r"BS$\downarrow$"),
)
for panel, title in axis_titles:
    panel.set_ylabel(title, fontsize=FONTSIZE-6)
ax3.set_xlabel("Probe", fontsize=FONTSIZE-6)
# Upper-case the probe names on the x axis.
ax3.set_xticklabels([tick.get_text().upper() for tick in ax3.get_xticklabels()])

# Second argument of set_yticks supplies the tick labels; presumably float16
# is used only so the labels print as short decimals.
ax1.set_ylim(0.5, 0.82)
ax1.set_yticks(np.linspace(0.5, 0.8, 4), np.linspace(0.5, 0.8, 4, dtype=np.float16))
ax3.set_ylim(0, 0.47)
for lower_ax in (ax2, ax3):
    lower_ax.set_yticks(np.linspace(0, 0.4, 3), np.linspace(0, 0.4, 3, dtype=np.float16))
# Dashed reference at 0.25, the Brier score of a constant 0.5 prediction.
ax3.axhline(0.25, ls='--', color='darkgray')

plt.setp(ax1.get_legend().get_texts(), fontsize=FONTSIZE-4)

for panel in (ax1, ax2, ax3):
    panel.tick_params(left=True, labelsize=FONTSIZE-6)
    # Print each bar's value (three decimals) next to the bar.
    for bars in panel.containers:
        panel.bar_label(bars, fmt='%.3f', fontsize=10)

plt.tight_layout()

plt.savefig("figures_qa/triviaqa.pdf", format='pdf', bbox_inches='tight')


# TriviaQA Selective QA

In [None]:
import joblib
import numpy as np
import os
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn

from probes import *
from utils import *


# Selective QA on TriviaQA: for each model, compare the hard-coded "Vanilla"
# score with the accuracy obtained when only the answers accepted by a probe
# (P(label == 1) > 0.5) are kept.
prompt_option = "no_prompt"

# Maps each model to the index of the layer whose saved activations
# (acts_<index>.npy) are evaluated. Commented-out entries are kept as a record
# of the indices used for the other models.
model_names_layer_indices = {
    # "Llama-2-7b-hf": 12,
    # "Llama-2-7b-chat-hf": 13,
    "Llama-2-13b-hf": 13,
    # "Llama-2-13b-chat-hf": 13,
    "Meta-Llama-3.1-8B-hf": 12,
    # "Meta-Llama-3.1-8B-Instruct-hf": 13,
    # "Meta-Llama-3.1-70B-hf": 33,
    # "Meta-Llama-3.1-70B-Instruct-hf": 37,
    # "Mistral-Large-Instruct-2407": 37,
}

# Hard-coded baseline score per model, plotted as the "Vanilla" bar.
# NOTE(review): presumably the exact-match score without any answer
# filtering, measured elsewhere -- confirm where these numbers come from.
vanilla_accs = {
    "Llama-2-13b-hf": 0.5824,
    "Meta-Llama-3.1-8B-hf": 0.55168,
}

probe_names = ('lr', 'mlp', 'svm', 'mm')


def _display_name(model_name):
    """Return the short model name used as the 'Model' (hue) label."""
    return model_name.replace("-hf","").replace("Meta-","").replace('b','B').replace("-chat","-Chat")


# Long-format table consumed by the plotting cell: one "Vanilla" row per model
# plus one row per (model, seed, probe).
disp = {
    'Method': [], 'Exact Match': [], 'Model': []
}
for model_name, layer_index in model_names_layer_indices.items():
    print(f"model={model_name}")
    activations_dir = f"activations_and_labels/{model_name}/{prompt_option}"

    test_topic = f"triviaqa_true_false_{model_name}"
    test_activations = np.load(f"{activations_dir}/{test_topic}/acts_{layer_index}.npy")
    test_labels = np.array(pd.read_csv(f"data/{test_topic}.csv")['label'].tolist())

    # Use the same label for the baseline row and the probe rows; otherwise
    # each model shows up as two separate hue categories in the plot.
    model_label = _display_name(model_name)

    disp['Method'].append("Vanilla")
    disp['Exact Match'].append(vanilla_accs[model_name])
    disp['Model'].append(model_label)

    for seed in (0,1,2):
        print(f"  seed={seed}")
        probes_dir = f"probes/{model_name}/seed={seed}"
        # joblib.load unpickles the files: only load probe files you trust.
        probes = {
            name: joblib.load(os.path.join(probes_dir, f"{name}.joblib"))
            for name in probe_names
        }

        for probe_name in probe_names:
            print(f"    probe={probe_name}")
            pos_probs = probes[probe_name].predict_proba(test_activations)[:,1]
            # Keep only the answers the probe accepts, then measure the
            # fraction of those whose label is 1.
            selected_answers = test_labels[pos_probs>0.5]
            if len(selected_answers):
                cacc = np.sum(selected_answers) / len(selected_answers)
            else:
                # Nothing was accepted: the ratio is undefined.
                cacc = float('nan')
            disp['Method'].append(probe_name.upper())
            disp['Exact Match'].append(cacc)
            disp['Model'].append(model_label)


In [None]:
# Selective-QA figure: exact match per method, one bar colour per model, with
# standard-deviation error bars. Written to
# figures_selective_qa/selective_qa.pdf.
seaborn.set_theme('paper', 'whitegrid')

TICK_LABEL_SIZE = 16
fig, ax = plt.subplots(figsize=(6, 3.5))
seaborn.barplot(
    disp,
    x='Method',
    y='Exact Match',
    hue='Model',
    legend=True,
    width=0.6,
    palette=['b', 'g'],
    errorbar='sd',
    capsize=0.2,
    ax=ax,
)
# ax.set_xticklabels(ax.get_xticklabels(), fontsize=TICK_LABEL_SIZE)
# ax.set_yticklabels(ax.get_yticklabels(), fontsize=TICK_LABEL_SIZE)
# Keep the x label text chosen by seaborn and only enlarge it; the y label is
# shortened to "EM".
ax.set_xlabel(ax.get_xlabel(), fontsize=TICK_LABEL_SIZE)
ax.set_ylabel("EM", fontsize=TICK_LABEL_SIZE)
ax.tick_params(left=True, labelsize=TICK_LABEL_SIZE)
# ax.set_yticks(np.linspace(0,0.9,4), np.linspace(0,0.9,4))
ax.set_ylim(0, .85)
# for cont in ax.containers:
#     ax.bar_label(cont, fmt='%.3f', fontsize=TICK_LABEL_SIZE)
plt.tight_layout()
plt.savefig("figures_selective_qa/selective_qa.pdf", format='pdf')

In [None]:
# Show the collected selective-QA table (Method / Exact Match / Model lists)
# as this cell's output.
disp

# Framework figure

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sklearn.linear_model
# Framework illustration: fit a logistic regression to 2-D points, draw its
# decision boundary with the two half-planes shaded, and label the regions
# "False" / "True". Written to hyperplane.pdf.
# NOTE(review): rc() is called here without any settings, so this looks like a
# leftover -- confirm whether a text option (e.g. usetex) was intended.
plt.rc('text')

# linpts.txt: the first two columns are used as coordinates, the third as the
# integer class label.
pts = np.loadtxt('linpts.txt')
X, Y = pts[:, :2], pts[:, 2].astype('int')

# Fit the data to a logistic regression model.
clf = sklearn.linear_model.LogisticRegression()
clf.fit(X, Y)

# The boundary w1*x + w2*y + b = 0, rewritten as y = m*x + c.
b = clf.intercept_[0]
w1, w2 = clf.coef_[0]
m = -w1 / w2
c = -b / w2

# Plot window, and the boundary evaluated at its left and right edges.
xmin, xmax = -0.6, 1.4
ymin, ymax = -0.8, 2.2
xd = np.array([xmin, xmax])
yd = m * xd + c

plt.plot(xd, yd, 'k', lw=1, ls='--')
# Shade from the boundary down to ymin in blue and up to ymax in orange.
for y_edge, shade in ((ymin, 'tab:blue'), (ymax, 'tab:orange')):
    plt.fill_between(xd, yd, y_edge, color=shade, alpha=0.2)

# One scatter call per class, class 0 first, so each takes the next colour of
# the default cycle.
for cls in (0, 1):
    members = X[Y == cls]
    plt.scatter(members[:, 0], members[:, 1], s=8, alpha=0.5)

plt.xlim(xmin, xmax)
plt.ylim(ymin, ymax)

plt.axis('off')

plt.text(xmin+0.3, ymin+0.3, 'False', fontsize=40)
plt.text(xmax-0.6, ymax-0.5, 'True', fontsize=40)

plt.savefig('hyperplane.pdf', dpi=400, bbox_inches='tight')
