In [1]:
import sys
# Put the project's source directory on the import path (resolved relative to the
# notebook's working directory) — presumably where the local `utils` module lives.
sys.path.append("../src/")

# Enable hot autoreload: mode 2 reloads imported modules before executing each cell,
# so edits to the source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
from sklearn.metrics import roc_curve, auc
import torch
import pandas as pd
from tqdm.notebook import tqdm
import itertools
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM
from utils import compute_perplexity_df

In [31]:
# Locations of the LLaMA tokenizer and model weights (placeholders; fill in before running).
LLAMA_TOKENIZER_PATH = "<LLAMA_TOKENIZER_PATH>"
LLAMA_MODEL_PATH = "<LLAMA_MODEL_PATH>"

# Pickled pd.DataFrame, produced by inject_traps.py.
CANARY_INFO_PATH = "<CANARY_INFO_PATH>"

# Additional non-member canaries, generated with the script `gen_traps.py`.
# This is a %-format template: `%d` is replaced by the sequence length when the files are loaded.
NON_MEMBERS_50_SMALL_PATH_TEMPLATE = "<PATH_TO_NON_MEMBER_TRAPS_SEQ_LEN_%d>"

In [None]:
# `torch_dtype` only has meaning for model weights. It was previously passed to the
# tokenizer (where it has no effect), so the LLaMA model was silently loaded in full
# precision. Request half precision on the model itself instead.
llama_tokenizer = LlamaTokenizer.from_pretrained(LLAMA_TOKENIZER_PATH)
llama_model = LlamaForCausalLM.from_pretrained(LLAMA_MODEL_PATH, torch_dtype=torch.float16)

# Target model under evaluation.
croissant_tokenizer = AutoTokenizer.from_pretrained("croissantllm/base_190k")
croissant_model = AutoModelForCausalLM.from_pretrained("croissantllm/base_190k")

# Reuse the end-of-sequence token as the padding token for both tokenizers.
croissant_tokenizer.pad_token = croissant_tokenizer.eos_token
llama_tokenizer.pad_token = llama_tokenizer.eos_token

# Place each model on its own GPU.
croissant_model = croissant_model.to("cuda:0")
llama_model = llama_model.to("cuda:1")

In [5]:
# NOTE: unpickling can execute arbitrary code — only point these paths at trusted files.
with open(CANARY_INFO_PATH, "rb") as canary_file:
    canary_info = pickle.load(canary_file)

# Decode each canary's token sequence back into text with the LLaMA tokenizer.
canary_info["raw_canaries"] = llama_tokenizer.batch_decode(canary_info["canary_tokens"])

# Non-member canaries, one pickle per sequence length, keyed by that length.
non_members = {}
for seq_len in (25, 50, 100):
    with open(NON_MEMBERS_50_SMALL_PATH_TEMPLATE % seq_len, "rb") as non_member_file:
        non_members[seq_len] = pickle.load(non_member_file)

In [6]:
# Distinct values of each experimental factor present in the canary table.
seq_len_set = set(canary_info["seq_len"])
ppl_set = set(canary_info["ppl_bucket"])
n_rep_set = set(canary_info["n_rep"])

In [None]:
# Collect one result frame per configuration and concatenate once at the end.
# Growing a DataFrame with pd.concat inside the loop re-copies every previous row
# on each iteration (quadratic) and concatenates onto an empty frame.
result_frames = []

# Sort the factor values so the row order of `df_res` is deterministic; later cells
# truncate groups positionally, so the order affects which rows are kept.
seq_len_values = sorted(seq_len_set)
ppl_values = sorted(ppl_set)
n_rep_values = sorted(n_rep_set)

# Members: canaries listed in `canary_info`, grouped by (seq_len, ppl bucket, n_rep).
for seq_len, ppl, n_rep in tqdm(list(itertools.product(seq_len_values, ppl_values, n_rep_values))):
    df_filter = canary_info[
        (canary_info.seq_len == seq_len)
        & (canary_info.ppl_bucket == ppl)
        & (canary_info.n_rep == n_rep)
    ]

    # The cartesian product may contain combinations with no canaries; skip them
    # rather than running the models on an empty input.
    if df_filter.empty:
        continue

    df_res_tmp = compute_perplexity_df(
        llama_model=llama_model,
        croissant_model=croissant_model,
        llama_tokenizer=llama_tokenizer,
        croissant_tokenizer=croissant_tokenizer,
        raw_canaries=df_filter.raw_canaries,
    )

    df_res_tmp["seq_len"] = seq_len
    df_res_tmp["ppl"] = ppl
    df_res_tmp["n_rep"] = n_rep

    result_frames.append(df_res_tmp)


# Non-members: separately generated canaries, recorded with n_rep = 0.
for seq_len, ppl in tqdm(list(itertools.product(seq_len_values, ppl_values))):
    # Non-member pickles are keyed by a pair derived from the bucket index,
    # e.g. bucket 0 -> (1, 11), bucket 1 -> (11, 21).
    key = (ppl * 10 + 1, ppl * 10 + 11)
    # Drop the first token of every sequence before decoding
    # (presumably the BOS token — TODO confirm against gen_traps.py).
    raw_canaries = llama_tokenizer.batch_decode(non_members[seq_len][key][:, 1:])

    df_res_tmp = compute_perplexity_df(
        llama_model=llama_model,
        croissant_model=croissant_model,
        llama_tokenizer=llama_tokenizer,
        croissant_tokenizer=croissant_tokenizer,
        raw_canaries=raw_canaries,
    )
    df_res_tmp["seq_len"] = seq_len
    df_res_tmp["ppl"] = ppl
    df_res_tmp["n_rep"] = 0

    result_frames.append(df_res_tmp)

# Single concatenation; the per-frame indices are kept as they are.
df_res = pd.concat(result_frames)

# Ratio score: target-model perplexity relative to the reference (LLaMA) perplexity.
df_res["ratio"] = df_res.croissant_ppl / df_res.llama_ppl

In [None]:
# One ROC panel per (sequence length, number of repetitions): rows are seq_len, columns are n_rep.
fig, axs = plt.subplots(3, 4, figsize=(18, 12))

# (seq_len, n_rep, AUC of the ratio score) for every panel.
aucs = []

for i, seq_len in enumerate([25, 50, 100]):
    for j, n_rep in enumerate([1, 10, 100, 1000]):
        members_tmp = df_res[(df_res.seq_len == seq_len) & (df_res.n_rep == n_rep)]
        non_members_tmp = df_res[(df_res.seq_len == seq_len) & (df_res.n_rep == 0)]

        # Balance the two classes by keeping the first n rows of each (positional truncation).
        n = min(len(members_tmp), len(non_members_tmp))
        members_tmp = members_tmp[:n]
        non_members_tmp = non_members_tmp[:n]

        # Labels: members are 0 and non-members are 1, i.e. non-members are the positive
        # class passed to roc_curve. Scores are listed in the same order (members first).
        y_true = [0] * len(members_tmp) + [1] * len(non_members_tmp)

        # Three candidate scores: log-perplexity of the target model, the perplexity
        # ratio, and the negated `minkprob` column.
        y_raw = [
            *np.log(members_tmp["croissant_ppl"]),
            *np.log(non_members_tmp["croissant_ppl"]),
        ]
        y_ratio = [*members_tmp["ratio"], *non_members_tmp["ratio"]]
        y_minkprob = [*-members_tmp["minkprob"], *-non_members_tmp["minkprob"]]

        fpr_baseline, tpr_baseline, _ = roc_curve(y_true, y_raw)
        fpr_ratio, tpr_ratio, _ = roc_curve(y_true, y_ratio)
        fpr_minkprob, tpr_minkprob, _ = roc_curve(y_true, y_minkprob)

        auc_ratio = auc(fpr_ratio, tpr_ratio)
        auc_baseline = auc(fpr_baseline, tpr_baseline)
        auc_minkprob = auc(fpr_minkprob, tpr_minkprob)

        aucs.append((seq_len, n_rep, auc_ratio))

        ax = axs[i, j]

        ax.plot(fpr_baseline, tpr_baseline, label=f"LOSS (AUC={auc_baseline:.3f})")
        ax.plot(fpr_ratio, tpr_ratio, label=f"Ratio (AUC={auc_ratio:.3f})")
        ax.plot(fpr_minkprob, tpr_minkprob, label=f"Min k-prob (AUC={auc_minkprob:.3f})")
        # Chance-level diagonal for reference.
        ax.plot([0, 1], [0, 1], color="navy", lw=1, linestyle="--")
        ax.set_xlabel("False Positive Rate")
        # Label the y-axis on the leftmost column only. The previous check used the row
        # index, which labelled every panel of the top row and none of the rows below.
        if j == 0:
            ax.set_ylabel("True Positive Rate")
        ax.set_title(f"seq_len={seq_len}, n_rep = {n_rep}")
        ax.legend(loc="lower right")

plt.show()