In [1]:
import sys
sys.path.append("../src/")

# Enable hot autoreload
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import numpy as np
from sklearn.metrics import roc_auc_score
import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig
from utils import compute_perplexity_df
from tqdm.notebook import tqdm

In [None]:
# ---- Paths to fill in before running ------------------------------------------------
# Reference model (Llama): the Croissant log-perplexity is divided by its log-perplexity.
LLAMA_TOKENIZER_PATH = "<LLAMA_TOKENIZER_PATH>"
LLAMA_MODEL_PATH = "<LLAMA_MODEL_PATH>"

# Pickled pandas DataFrame describing the injected (member) traps; written by inject_traps.py.
TRAP_INFO_PATH = "<TRAP_INFO_PATH>"

# %-style template taking a sequence length (%d). It points at the extra non-member trap
# sequences generated with the script gen_traps.py.
NON_MEMBERS_PATH_TEMPLATE = "<PATH_TO_NON_MEMBER_TRAPS_SEQ_LEN_%d>"

In [None]:
# Device for the Llama reference model. Unquantized Croissant models are placed on cuda:0
# by default in get_croissant_model.
LLAMA_DEVICE = "cuda:1"

# Tokenizers have no dtype, so no torch_dtype argument is passed here.
llama_tokenizer = LlamaTokenizer.from_pretrained(LLAMA_TOKENIZER_PATH)
# Reuse EOS as the pad token so batched inputs can be padded
# (assumes compute_perplexity_df pads batches — TODO confirm).
llama_tokenizer.pad_token = llama_tokenizer.eos_token

# NOTE(review): no torch_dtype is passed, so the weights load in the library's default dtype
# (float32 in transformers 4.x). If fp16 was intended, pass torch_dtype=torch.float16 here;
# be aware this changes the measured Llama perplexities.
llama_model = LlamaForCausalLM.from_pretrained(LLAMA_MODEL_PATH).to(LLAMA_DEVICE)

croissant_tokenizer = AutoTokenizer.from_pretrained("croissantllm/base_190k")
croissant_tokenizer.pad_token = croissant_tokenizer.eos_token

In [5]:
# NOTE: pickle.load can execute arbitrary code; only point these paths at files you produced.

# Member traps: a DataFrame whose token ids are decoded back into text for both models.
with open(TRAP_INFO_PATH, "rb") as trap_file:
    trap_info = pickle.load(trap_file)
trap_info["raw_traps"] = llama_tokenizer.batch_decode(trap_info.trap_tokens)

# Non-member traps, one pickle per sequence length, indexed by that length.
non_members = dict()
for seq_len in (25, 50, 100):
    with open(NON_MEMBERS_PATH_TEMPLATE % seq_len, "rb") as nm_file:
        non_members[seq_len] = pickle.load(nm_file)

In [6]:
def get_croissant_model(
    croissant_checkpoint: str,
    torch_dtype: str,
    cache_dir: str = "/home/igor/rds/ephemeral/.huggingface",
    device: str = "cuda:0",
):
    """Load a CroissantLLM checkpoint at the requested precision.

    Args:
        croissant_checkpoint: Model id or local path passed to
            ``AutoModelForCausalLM.from_pretrained``.
        torch_dtype: One of ``"int4"`` or ``"int8"`` (bitsandbytes quantization, placed
            with ``device_map="auto"``), or ``"float16"`` or ``"float32"`` (loaded
            unquantized and moved to ``device``).
        cache_dir: Hugging Face download cache directory. The default is the original
            machine-specific location; override it elsewhere.
        device: Target device for unquantized models. It is ignored for int4/int8.

    Returns:
        The loaded causal language model.

    Raises:
        ValueError: if ``torch_dtype`` is not one of the supported values. Previously any
            unrecognized string silently loaded float16.
    """
    model_load_kwargs = {"cache_dir": cache_dir}
    if torch_dtype == "int8":
        model_load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True,
            # In bitsandbytes a threshold of 0.0 turns off the fp16 outlier decomposition,
            # so every matmul runs in int8.
            llm_int8_threshold=0.0,
        )
    elif torch_dtype == "int4":
        model_load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    elif torch_dtype == "float32":
        model_load_kwargs["torch_dtype"] = torch.float32
    elif torch_dtype == "float16":
        model_load_kwargs["torch_dtype"] = torch.float16
    else:
        raise ValueError(
            f"Unsupported torch_dtype {torch_dtype!r}; "
            "expected one of 'int4', 'int8', 'float16', 'float32'"
        )

    if "quantization_config" in model_load_kwargs:
        # Quantized models are placed by accelerate; they must not be moved with .to().
        model_load_kwargs["device_map"] = "auto"
        croissant_model = AutoModelForCausalLM.from_pretrained(croissant_checkpoint, **model_load_kwargs)
    else:
        print("Using device:", device)
        croissant_model = AutoModelForCausalLM.from_pretrained(croissant_checkpoint, **model_load_kwargs).to(device)

    return croissant_model

In [7]:
def _log_ppl_ratio(df):
    """Per-row log(croissant_ppl) / log(llama_ppl) as a float numpy array."""
    croissant = np.log(df["croissant_ppl"].to_numpy(dtype=float))
    llama = np.log(df["llama_ppl"].to_numpy(dtype=float))
    return croissant / llama


def compute_auc(df_res, seq_len, n_rep):
    """Membership-inference ROC AUC for one (sequence length, repetition count) cell.

    Each trap is scored as ``log(croissant_ppl) / log(llama_ppl)``. The score is negated
    before calling ``roc_auc_score``, so a lower ratio counts as stronger evidence of
    membership. Members get label 1 and non-members label 0.

    Args:
        df_res: DataFrame with columns ``seq_len``, ``n_rep``, ``croissant_ppl`` and
            ``llama_ppl``. Rows with ``n_rep == 0`` are the non-members.
        seq_len: Sequence length to evaluate. Both classes are filtered on it.
        n_rep: Repetition count of the member traps. It must be non-zero, because 0
            marks non-members.

    Returns:
        The ROC AUC.

    Raises:
        ValueError: if ``n_rep`` is 0, or if the member or non-member subset is empty.
    """
    if n_rep == 0:
        raise ValueError("n_rep=0 marks non-members; pass the member repetition count")

    same_len = df_res["seq_len"] == seq_len
    df_members = df_res[same_len & (df_res["n_rep"] == n_rep)]
    df_non_members = df_res[same_len & (df_res["n_rep"] == 0)]
    if df_members.empty or df_non_members.empty:
        raise ValueError(
            f"Need both members and non-members for seq_len={seq_len}, n_rep={n_rep}; "
            f"got {len(df_members)} members and {len(df_non_members)} non-members"
        )

    score_members = _log_ppl_ratio(df_members)
    score_non_members = _log_ppl_ratio(df_non_members)
    y_score = np.concatenate([score_members, score_non_members])
    y_true = np.concatenate([
        np.ones(len(score_members), dtype=int),
        np.zeros(len(score_non_members), dtype=int),
    ])

    # Negate: lower log-ppl ratio => more likely a member.
    return roc_auc_score(y_true, -y_score)

In [8]:
# Cell of the experiment to evaluate: member traps of this length that were injected
# this many times.
seq_len, n_rep = 100, 1000
# Every perplexity bucket that appears in the member-trap metadata.
ppl_set = set(trap_info["ppl_bucket"])

In [None]:
import gc

DTYPES = ["int4", "int8", "float16", "float32"]
CROISSANT_CHECKPOINT = "croissantllm/base_190k"


def score_traps(croissant_model, raw_traps, seq_len, ppl, n_rep):
    """Compute both models' perplexities for ``raw_traps`` and tag the rows.

    ``n_rep`` is written into the ``n_rep`` column: the repetition count for member
    traps, or 0 for non-member traps (the label compute_auc expects).
    """
    df_scores = compute_perplexity_df(
        llama_model=llama_model,
        croissant_model=croissant_model,
        llama_tokenizer=llama_tokenizer,
        croissant_tokenizer=croissant_tokenizer,
        raw_traps=raw_traps,
    )
    df_scores["seq_len"] = seq_len
    df_scores["ppl"] = ppl
    df_scores["n_rep"] = n_rep
    return df_scores


auc_by_dtype = {}
df_res_by_dtype = {}
for dtype in tqdm(DTYPES):
    # Release the previous precision's model before loading the next one. Otherwise both
    # copies are referenced while the new one loads, which can exhaust GPU memory.
    croissant_model = None
    gc.collect()
    torch.cuda.empty_cache()
    croissant_model = get_croissant_model(CROISSANT_CHECKPOINT, dtype)

    frames = []

    # Members: traps of this length and repetition count, one batch per perplexity bucket.
    for ppl in ppl_set:
        df_filter = trap_info[
            (trap_info.seq_len == seq_len)
            & (trap_info.ppl_bucket == ppl)
            & (trap_info.n_rep == n_rep)
        ]
        frames.append(score_traps(croissant_model, df_filter.raw_traps, seq_len, ppl, n_rep))

    # Non-members: freshly generated traps for the same buckets, labelled n_rep=0.
    for ppl in ppl_set:
        # NOTE(review): key looks like a (low, high) perplexity range for bucket `ppl`,
        # e.g. ppl=2 -> (21, 31) — TODO confirm against gen_traps.py.
        key = (ppl * 10 + 1, ppl * 10 + 11)
        # Presumably drops a leading BOS token column before decoding — TODO confirm.
        raw_traps = llama_tokenizer.batch_decode(non_members[seq_len][key][:, 1:])
        frames.append(score_traps(croissant_model, raw_traps, seq_len, ppl, 0))

    # Concatenate once instead of growing a DataFrame inside the loops.
    df_res = pd.concat(frames, ignore_index=True)
    df_res["ratio"] = df_res.croissant_ppl / df_res.llama_ppl
    df_res_by_dtype[dtype] = df_res

    auc = compute_auc(df_res, seq_len, n_rep)
    auc_by_dtype[dtype] = auc
    print(f"dtype={dtype}, AUC={auc}")