In [1]:
import sys
# Add ../src/ to the import search path so modules located there can be
# imported from this notebook.
sys.path.append("../src/")

# Enable hot autoreload: in mode 2, modules are reloaded before each cell is
# executed, so edits to imported source files take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [16]:
import pickle
import numpy as np
from sklearn.metrics import roc_curve, auc
import torch
import pandas as pd
from scipy.stats import pearsonr, linregress
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM
from utils import compute_perplexity_df

In [None]:
# Locations handed to `from_pretrained` for the LLaMA tokenizer and model.
# The values below are placeholders and must be filled in before running.
LLAMA_TOKENIZER_PATH = "<LLAMA_TOKENIZER_PATH>"
LLAMA_MODEL_PATH = "<LLAMA_MODEL_PATH>"

# Pickled pd.DataFrame, produced by inject_traps.py
TRAP_INFO_PATH = "<TRAP_INFO_PATH>"

# Additional non-member canaries, generated with the script `gen_traps.py`.
# The template is filled with the sequence length through the `%d` placeholder.
NON_MEMBERS_PATH_TEMPLATE = "<PATH_TO_NON_MEMBER_TRAPS_SEQ_LEN_%d>"

In [None]:
# LLaMA tokenizer and model.
# `torch_dtype` is a model-loading option: it selects the precision of the
# weights and has no effect on a tokenizer, so it is passed to the model here.
llama_tokenizer = LlamaTokenizer.from_pretrained(LLAMA_TOKENIZER_PATH)
llama_model = LlamaForCausalLM.from_pretrained(LLAMA_MODEL_PATH, torch_dtype=torch.float16)

# CroissantLLM tokenizer and model, fetched by their hub identifier.
croissant_tokenizer = AutoTokenizer.from_pretrained("croissantllm/base_190k")
croissant_model = AutoModelForCausalLM.from_pretrained("croissantllm/base_190k")

# Reuse the end-of-sequence token as padding token for both tokenizers
# (presumably required for batched tokenization in `compute_perplexity_df` -- confirm).
croissant_tokenizer.pad_token = croissant_tokenizer.eos_token
llama_tokenizer.pad_token = llama_tokenizer.eos_token

# Each model lives on its own GPU; this cell therefore needs two CUDA devices.
croissant_model = croissant_model.to("cuda:0")
llama_model = llama_model.to("cuda:1")

In [20]:
# Load the trap metadata.
# NOTE: `pickle.load` can execute arbitrary code contained in the file; only
# load artefacts that were produced by trusted scripts.
with open(TRAP_INFO_PATH, "rb") as trap_file:
    trap_info = pickle.load(trap_file)
# Decode the stored canary token ids back into text with the LLaMA tokenizer.
trap_info["raw_canaries"] = llama_tokenizer.batch_decode(trap_info.canary_tokens)

# Load the non-member trap sequences, keyed by sequence length.
# The loop variable is intentionally NOT named `seq_len`: that name is the
# notebook-level analysis parameter, and re-running this cell would otherwise
# silently reset it to the last value of the list below.
non_members = {}
for nm_seq_len in [25, 50, 100]:
    with open(NON_MEMBERS_PATH_TEMPLATE % nm_seq_len, "rb") as nm_file:
        non_members[nm_seq_len] = pickle.load(nm_file)

In [22]:
# Trap configuration analysed in the cells below: sequence length and
# number of repetitions used to select rows of `trap_info`.
seq_len = 100
n_rep = 1000

# Distinct perplexity buckets present in the trap metadata.
ppl_set = set(trap_info["ppl_bucket"])

In [1]:
def _score_canaries(raw_canaries, ppl, n_rep_value):
    """Score one batch of canaries with both models and tag the result.

    Calls `compute_perplexity_df` with the notebook-level models and
    tokenizers, then adds the `seq_len`, `ppl` and `n_rep` columns that the
    later cells filter on.

    Args:
        raw_canaries: decoded canary texts to score.
        ppl: perplexity bucket the canaries belong to.
        n_rep_value: value stored in the `n_rep` column
            (`n_rep` for members, 0 for non-members).

    Returns:
        The frame returned by `compute_perplexity_df`, with the three extra columns.
    """
    scores = compute_perplexity_df(
        llama_model=llama_model,
        croissant_model=croissant_model,
        llama_tokenizer=llama_tokenizer,
        croissant_tokenizer=croissant_tokenizer,
        raw_canaries=raw_canaries,
    )
    scores["seq_len"] = seq_len
    scores["ppl"] = ppl
    scores["n_rep"] = n_rep_value
    return scores


# Collect the per-bucket frames and concatenate once at the end instead of
# growing a DataFrame inside the loops (quadratic copying, and concatenating
# onto an empty frame is deprecated in recent pandas versions).
score_frames = []

# Members: traps of the selected length and repetition count, bucket by bucket.
# Buckets are visited in sorted order so the row order is deterministic.
for ppl in sorted(ppl_set):
    member_traps = trap_info[
        (trap_info.seq_len == seq_len)
        & (trap_info.ppl_bucket == ppl)
        & (trap_info.n_rep == n_rep)
    ]
    score_frames.append(_score_canaries(member_traps.raw_canaries, ppl, n_rep))

# Non-members: stored as token ids under a key derived from the bucket.
for ppl in sorted(ppl_set):
    # presumably the (low, high) perplexity bounds of the bucket -- confirm against gen_traps.py
    key = (ppl * 10 + 1, ppl * 10 + 11)
    # The first token of every sequence is dropped before decoding
    # (presumably the BOS token -- confirm).
    raw_canaries = llama_tokenizer.batch_decode(non_members[seq_len][key][:, 1:])
    score_frames.append(_score_canaries(raw_canaries, ppl, 0))

# Original row indices are kept (no `ignore_index`), as before.
df_res = pd.concat(score_frames)

# Score used for the membership test in the next cell.
df_res["ratio"] = df_res.croissant_ppl / df_res.llama_ppl

In [24]:
# One (perplexity bucket, AUC) pair per bucket; consumed by the plotting cell.
data = []

# Iterate over the buckets that were actually scored (`ppl_set`) rather than a
# hard-coded range, so this cell stays consistent with the computation above.
for ppl in sorted(ppl_set):
    in_bucket = (df_res.seq_len == seq_len) & (df_res.ppl == ppl)
    members_tmp = df_res[in_bucket & (df_res.n_rep == n_rep)]
    non_members_tmp = df_res[in_bucket & (df_res.n_rep == 0)]

    # A ROC curve is undefined when one of the two classes is missing.
    if members_tmp.empty or non_members_tmp.empty:
        print(f"Skipping ppl bucket {ppl}: {len(members_tmp)} members, {len(non_members_tmp)} non-members")
        continue

    y_ratio = np.concatenate([members_tmp["ratio"].to_numpy(), non_members_tmp["ratio"].to_numpy()])
    # Label convention: members = 0, non-members = 1. `roc_curve` treats larger
    # scores as evidence for the positive class, so the positive class is the
    # one expected to have the larger ratio (presumably non-members -- confirm).
    y_true = np.concatenate(
        [
            np.zeros(len(members_tmp), dtype=int),
            np.ones(len(non_members_tmp), dtype=int),
        ]
    )

    fpr_ratio, tpr_ratio, _ = roc_curve(y_true, y_ratio)
    auc_ratio = auc(fpr_ratio, tpr_ratio)

    data.append((ppl, auc_ratio))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 5), sharey=True)

# Unpack the (bucket, AUC) pairs produced by the previous cell.
bucket_ids, auc_values = zip(*data)
ppl_values = np.array(bucket_ids) * 10  # bucket index scaled to the perplexity axis
ppl_plot = ppl_values + 5  # x positions used for drawing (presumably bucket midpoints -- confirm)

# Linear trend and correlation of AUC against perplexity.
fit = linregress(ppl_values, auc_values)
pearson_coeff, pearson_p_value = pearsonr(ppl_values, auc_values)

ax.scatter(
    ppl_plot,
    auc_values,
    label=r"$L_\text{ref} = %d, n_\text{rep} = %d$" % (seq_len, n_rep),
    s=60,
    c="darkgreen",
    linewidths=0.5,
    marker="o",
)
ax.plot(
    ppl_plot,
    fit.slope * ppl_values + fit.intercept,
    color="dimgrey",
    linestyle="--",
    label="Linear fit",
)

# Grid drawn underneath the data.
ax.grid()
ax.set_axisbelow(True)

ax.axhline(0.5, linestyle="--", linewidth=2, c="black", alpha=1, label="Random guess baseline")

ax.set_xlabel("Trap sequence perplexity", fontsize=15)
ax.set_ylabel("AUC", fontsize=17)
ax.set_ylim(0.3, 0.9)
ax.set_xlim(0, 100)
ax.legend(loc="lower right", fontsize=10)

plt.tight_layout()
plt.show()

print(f"seq_len={seq_len}, n_rep={n_rep}, pearson = {pearson_coeff:.3f}, p_value = {pearson_p_value:.2e}")