In [1]:
import sys
# Make modules located under ../src/ importable from this notebook.
sys.path.append("../src/")

# Enable hot autoreload so that edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [13]:
import pickle
import numpy as np
from sklearn.metrics import roc_auc_score
import torch
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM
from utils import compute_perplexity_df
import itertools
from collections import defaultdict
from tqdm.notebook import tqdm

In [None]:
# Locations of the Llama tokenizer and model (placeholders, to be filled in).
LLAMA_TOKENIZER_PATH = "<LLAMA_TOKENIZER_PATH>"
LLAMA_MODEL_PATH = "<LLAMA_MODEL_PATH>"

# Pickled pd.DataFrame, produced by inject_traps.py
TRAP_INFO_PATH = "<TRAP_INFO_PATH>"

# Additional non-member canaries, generated with the script `gen_traps.py`.
# The `%d` placeholder is filled with the sequence length when the files are loaded.
NON_MEMBERS_PATH_TEMPLATE = "<PATH_TO_NON_MEMBER_TRAPS_SEQ_LEN_%d>"

In [None]:
# Load the reference Llama model in half precision.
# NOTE: `torch_dtype` is a *model* loading option. It used to be passed to the
# tokenizer, where it has no effect, so the model was silently loaded in the
# default precision instead of float16.
llama_model = LlamaForCausalLM.from_pretrained(LLAMA_MODEL_PATH, torch_dtype=torch.float16)

# Neither tokenizer has a padding token configured here; reuse EOS for padding.
llama_tokenizer = LlamaTokenizer.from_pretrained(LLAMA_TOKENIZER_PATH)
llama_tokenizer.pad_token = llama_tokenizer.eos_token

croissant_tokenizer = AutoTokenizer.from_pretrained("croissantllm/base_190k")
croissant_tokenizer.pad_token = croissant_tokenizer.eos_token

# The reference model lives on its own GPU; the Croissant checkpoints are
# placed on "cuda:0" when they are loaded later in the notebook.
llama_model = llama_model.to("cuda:1")

In [5]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Trap metadata, plus the decoded text of every trap sequence.
trap_info = _load_pickle(TRAP_INFO_PATH)
trap_info["raw_canaries"] = llama_tokenizer.batch_decode(trap_info.canary_tokens)

# Non-member canaries, one pickle file per sequence length.
non_members = {
    seq_len: _load_pickle(NON_MEMBERS_PATH_TEMPLATE % seq_len)
    for seq_len in [25, 50, 100]
}

In [21]:
# Distinct sequence lengths and perplexity buckets present in the trap metadata.
seq_len_set = set(trap_info["seq_len"])
ppl_set = set(trap_info["ppl_bucket"])
# Value of the `n_rep` column used to select the member traps.
n_rep = 1000

In [26]:
def make_df_res(croissant_model):
    """Score all member and non-member canaries with both models.

    For every (sequence length, perplexity bucket) combination found in
    `trap_info`, perplexities are computed with `compute_perplexity_df` for

    * the member traps, i.e. rows of `trap_info` whose `n_rep` equals the
      notebook-level `n_rep`, and
    * the non-member canaries stored in `non_members`.

    Args:
        croissant_model: the Croissant checkpoint to evaluate; forwarded
            unchanged to `compute_perplexity_df`.

    Returns:
        pd.DataFrame: the concatenation of every frame returned by
        `compute_perplexity_df` (members first, then non-members), extended
        with the columns `seq_len`, `ppl`, `n_rep` (0 for non-members) and
        `ratio` (`croissant_ppl / llama_ppl`).
    """
    configs = list(itertools.product(seq_len_set, ppl_set))

    def _score(raw_canaries, seq_len, ppl, n_rep_label):
        # Compute perplexities for one group of canaries and tag the rows
        # with the configuration they belong to.
        df_group = compute_perplexity_df(
            llama_model=llama_model,
            croissant_model=croissant_model,
            llama_tokenizer=llama_tokenizer,
            croissant_tokenizer=croissant_tokenizer,
            raw_canaries=raw_canaries,
        )
        df_group["seq_len"] = seq_len
        df_group["ppl"] = ppl
        df_group["n_rep"] = n_rep_label
        return df_group

    # Collect the partial results and concatenate once at the end: growing a
    # DataFrame with pd.concat inside the loop re-copies all previous rows on
    # every iteration, and concatenating with an empty frame is deprecated.
    frames = []

    # Members: traps injected into the training data.
    for seq_len, ppl in tqdm(configs):
        df_filter = trap_info[
            (trap_info.seq_len == seq_len)
            & (trap_info.ppl_bucket == ppl)
            & (trap_info.n_rep == n_rep)
        ]
        frames.append(_score(df_filter.raw_canaries, seq_len, ppl, n_rep))

    # Non-members: canaries generated separately.
    for seq_len, ppl in tqdm(configs):
        # Non-members are keyed by a pair derived from the perplexity bucket
        # (presumably the perplexity range of that bucket -- confirm against
        # gen_traps.py).
        key = (ppl * 10 + 1, ppl * 10 + 11)
        # The first token of every sequence is dropped before decoding
        # (presumably the BOS token -- TODO confirm).
        raw_canaries = llama_tokenizer.batch_decode(non_members[seq_len][key][:, 1:])
        frames.append(_score(raw_canaries, seq_len, ppl, 0))

    df_res = pd.concat(frames)
    df_res["ratio"] = df_res.croissant_ppl / df_res.llama_ppl
    return df_res

In [27]:
def compute_auc(df_res, seq_len=100, n_rep=1000):
    """Compute the membership ROC AUC for one sequence length.

    Rows of `df_res` with the requested `seq_len` are split into members
    (`n_rep` column equal to `n_rep`, label 1) and non-members (`n_rep`
    column equal to 0, label 0). Each row is scored with
    `log(croissant_ppl) / log(llama_ppl)`; the score is negated before being
    handed to `roc_auc_score`, so that a lower ratio ranks as "more likely a
    member".

    Args:
        df_res: frame with columns `seq_len`, `n_rep`, `croissant_ppl` and
            `llama_ppl`.
        seq_len: sequence length to evaluate.
        n_rep: value of the `n_rep` column identifying the members.

    Returns:
        The value returned by `roc_auc_score`.
    """
    has_seq_len = df_res.seq_len == seq_len

    def _log_ppl_ratio(rows):
        # Ratio of log-perplexities: target model over reference model.
        return np.log(rows["croissant_ppl"]) / np.log(rows["llama_ppl"])

    member_scores = _log_ppl_ratio(df_res[has_seq_len & (df_res.n_rep == n_rep)])
    non_member_scores = _log_ppl_ratio(df_res[has_seq_len & (df_res.n_rep == 0)])

    labels = [1] * len(member_scores) + [0] * len(non_member_scores)
    scores = np.array(list(member_scores) + list(non_member_scores))

    return roc_auc_score(labels, -scores)

In [None]:
# Training checkpoints to evaluate, in thousands of steps.
checkpoints = (5, 25, 45, 65, 85, 105, 125, 145, 165, 190)
# For each sequence length: one AUC per checkpoint, in `checkpoints` order.
all_aucs = defaultdict(list)

# `make_df_res` scores every sequence length at once and does not depend on
# `seq_len`, so each checkpoint is loaded and scored exactly once and all AUCs
# are derived from the same result frame. (Iterating over sequence lengths in
# the outer loop repeats the download/load and the full perplexity computation
# three times per checkpoint for identical results.)
croissant_model = None
for checkpoint in tqdm(checkpoints):
    # Release the previous checkpoint before loading the next one so that two
    # Croissant models never have to fit on the GPU at the same time.
    croissant_model = None
    torch.cuda.empty_cache()

    checkpoint_name = f"croissantllm/base_{checkpoint}k"
    # NOTE(review): the cache directory is a machine-specific absolute path.
    croissant_model = AutoModelForCausalLM.from_pretrained(
        checkpoint_name, cache_dir="/home/igor/rds/ephemeral/.huggingface").to("cuda:0")

    df_res = make_df_res(croissant_model)

    for seq_len in (25, 50, 100):
        auc = compute_auc(df_res=df_res, seq_len=seq_len)
        all_aucs[seq_len].append(auc)

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), sharey=True)

# One colour per plotted sequence length.
custom_colors = ['darkred', 'darkblue', 'darkgreen']

for j, seq_len in enumerate((25, )):
    aucs_to_plot = all_aucs[seq_len]
    # Only as many checkpoints as there are computed AUCs are drawn.
    axes.plot(
        checkpoints[:len(aucs_to_plot)],
        aucs_to_plot,
        label=f"$L_{{ref}}$ = {seq_len}",
        color=custom_colors[j],
        marker="o",
        markersize=8,
        markeredgecolor='white',
        markeredgewidth=2,
        linewidth=2,
        alpha=1,
    )

# axes.set_title("Measuring memorization during training", fontsize = 13)
# AUC of a classifier that guesses at random.
axes.axhline(0.5, label='Random guess baseline', c='black', linestyle='--', linewidth=2, alpha=1)

# x axis: training steps in thousands, one tick every 20k steps.
xticks = np.arange(0, 201, 20)
axes.set_xlim(0, 200)
axes.set_xticks(xticks)
axes.set_xticklabels(xticks)
axes.set_xlabel("Training steps (in thousands)", fontsize=15)

# y axis: AUC, on a fixed range.
axes.set_ylabel("AUC", fontsize=15)
axes.set_ylim(0.3, 0.9)

axes.grid()
axes.legend(loc='upper left', fontsize=12)
plt.tight_layout()
plt.show()