# DetectGPT Exploration

In [1]:
INPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/input"
OUTPUT_DIR = "/cmlscratch/manlis/test/watermarking-root/output"

In [2]:
# Basic imports
import os

from tqdm import tqdm
from statistics import mean

import numpy as np
import pandas as pd
import torch

import matplotlib.pyplot as plt

from matplotlib import rc
rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})
rc('text', usetex=True)

import cmasher as cmr

### Load the processed dataset/frame

In [7]:
import sys
sys.path.insert(0, "..")
from datasets import Dataset
from utils.io import read_jsonlines, load_jsonlines

data_path = '/cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b/gen_table.jsonl'

list_of_dict = load_jsonlines(data_path)
raw_data = Dataset.from_list(list_of_dict)

Reading JSON lines from /cmlscratch/manlis/test/watermarking-root/input/new_runs/test_len_200_opt1_3b/gen_table.jsonl: 100%|██████████| 781/781 [00:00<00:00, 19661.31lines/s]


#### convert to pandas df

In [8]:
df = raw_data.to_pandas()

In [9]:
print(f"Orig number of rows: {len(df)}")
df.tail()

Orig number of rows: 781


Unnamed: 0,idx,truncated_input,baseline_completion,orig_sample_length,prompt_length,baseline_completion_length,no_wm_output,w_wm_output,no_wm_num_tokens_generated,w_wm_num_tokens_generated,spike_entropies
776,1106,With the arrival of better weather (recent rai...,"\nH.H. Haugh was described as ""Dealers in new ...",968,768,200,"\nI had a better question, though: ""Was H.H. H...","\nI had to check on this. You see, I don't kno...",200,200,"[0.417236328125, 0.94580078125, 0.94873046875,..."
777,1107,You probably don’t put much thought into craft...,it ripe for misuse.\nBecause email is a quick...,288,88,200,it easy to make mistakes.\nHere are some of t...,"it incredibly easy to send a bad email.\nSo, ...",200,200,"[0.6787109375, 0.88427734375, 0.72802734375, 0..."
778,1108,Welcome to Metro’s wine guide. Every week wine...,"when in fact we should, but differently. Ten ...",637,437,200,"because of the calories, and we do - but we a...",.\nThe other weekend a friend from the Sunshin...,200,200,"[0.79541015625, 0.87353515625, 0.95703125, 0.9..."
779,1109,By Tracey Maclin. Tracey Maclin is a professor...,warrantless search is permissible in an emerg...,962,762,200,"sweeping search of public housing units, for ...",more flexible approach would allow potential ...,200,200,"[0.96875, 0.9814453125, 0.8076171875, 0.755371..."
780,1112,Financial firms are not just responsible for e...,"the platform first launched. ""The management ...",498,298,200,asked about why lenders are looking for new s...,"talking about his firm's launch.\n""The system...",200,200,"[0.85888671875, 0.48583984375, 0.69140625, 0.7..."


In [10]:
df.columns

Index(['idx', 'truncated_input', 'baseline_completion', 'orig_sample_length',
       'prompt_length', 'baseline_completion_length', 'no_wm_output',
       'w_wm_output', 'no_wm_num_tokens_generated',
       'w_wm_num_tokens_generated', 'spike_entropies'],
      dtype='object')

## DetectGPT Model Loading

In [22]:
import transformers

# mask filling t5 model
mask_filling_model_name='t5-3b'

int8_kwargs = {}
half_kwargs = {}
half_kwargs = dict(torch_dtype=torch.bfloat16)
print(f'Loading mask filling model {mask_filling_model_name}...')
mask_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(mask_filling_model_name, **int8_kwargs, **half_kwargs)
## TODO: load the base model (for generation) and base tokenizer
try:
    n_positions = mask_model.config.n_positions
except AttributeError:
    n_positions = 512
preproc_tokenizer = transformers.AutoTokenizer.from_pretrained('t5-small', model_max_length=512)
mask_tokenizer = transformers.AutoTokenizer.from_pretrained(mask_filling_model_name, model_max_length=n_positions)

mask_model.cpu()

Loading mask filling model t5-3b...


Downloading (…)ve/main/spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 60.4MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 1.39M/1.39M [00:00<00:00, 87.1MB/s]


T5ForConditionalGeneration(
  (shared): Embedding(32128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=4096, bias=False)
              (k): Linear(in_features=1024, out_features=4096, bias=False)
              (v): Linear(in_features=1024, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=16384, bias=False)
              (wo): Linear(in_features=16384, out_features=1024, bias=False)
              

In [None]:
## perturbing text ops
import re

buffer_size = 1
# define regex to match all <extra_id_*> tokens, where * is an integer
pattern = re.compile(r"<extra_id_\d+>")
mask_top_p=1.0

def tokenize_and_mask(text, span_length, pct, ceil_pct=False):
    tokens = text.split(' ')
    mask_string = '<<<mask>>>'

    n_spans = pct * len(tokens) / (span_length + buffer_size * 2)
    if ceil_pct:
        n_spans = np.ceil(n_spans)
    n_spans = int(n_spans)

    n_masks = 0
    while n_masks < n_spans:
        start = np.random.randint(0, len(tokens) - span_length)
        end = start + span_length
        search_start = max(0, start - buffer_size)
        search_end = min(len(tokens), end + buffer_size)
        if mask_string not in tokens[search_start:search_end]:
            tokens[start:end] = [mask_string]
            n_masks += 1
    
    # replace each occurrence of mask_string with <extra_id_NUM>, where NUM increments
    num_filled = 0
    for idx, token in enumerate(tokens):
        if token == mask_string:
            tokens[idx] = f'<extra_id_{num_filled}>'
            num_filled += 1
    assert num_filled == n_masks, f"num_filled {num_filled} != n_masks {n_masks}"
    text = ' '.join(tokens)
    return text

def count_masks(texts):
    return [len([x for x in text.split() if x.startswith("<extra_id_")]) for text in texts]

# replace each masked span with a sample from T5 mask_model
def replace_masks(texts):
    n_expected = count_masks(texts)
    stop_id = mask_tokenizer.encode(f"<extra_id_{max(n_expected)}>")[0]
    tokens = mask_tokenizer(texts, return_tensors="pt", padding=True).cuda()
    outputs = mask_model.generate(**tokens, max_length=150, do_sample=True, top_p=mask_top_p, num_return_sequences=1, eos_token_id=stop_id)
    return mask_tokenizer.batch_decode(outputs, skip_special_tokens=False)


def extract_fills(texts):
    # remove <pad> from beginning of each text
    texts = [x.replace("<pad>", "").replace("</s>", "").strip() for x in texts]

    # return the text in between each matched mask token
    extracted_fills = [pattern.split(x)[1:-1] for x in texts]

    # remove whitespace around each fill
    extracted_fills = [[y.strip() for y in x] for x in extracted_fills]

    return extracted_fills


def apply_extracted_fills(masked_texts, extracted_fills):
    # split masked text into tokens, only splitting on spaces (not newlines)
    tokens = [x.split(' ') for x in masked_texts]

    n_expected = count_masks(masked_texts)

    # replace each mask token with the corresponding fill
    for idx, (text, fills, n) in enumerate(zip(tokens, extracted_fills, n_expected)):
        if len(fills) < n:
            tokens[idx] = []
        else:
            for fill_idx in range(n):
                text[text.index(f"<extra_id_{fill_idx}>")] = fills[fill_idx]

    # join tokens back into text
    texts = [" ".join(x) for x in tokens]
    return texts

def perturb_texts_(texts, span_length, pct, ceil_pct=False):
    masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for x in texts]
    raw_fills = replace_masks(masked_texts)
    extracted_fills = extract_fills(raw_fills)
    perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)

    # Handle the fact that sometimes the model doesn't generate the right number of fills and we have to try again
    attempts = 1
    while '' in perturbed_texts:
        idxs = [idx for idx, x in enumerate(perturbed_texts) if x == '']
        print(f'WARNING: {len(idxs)} texts have no fills. Trying again [attempt {attempts}].')
        masked_texts = [tokenize_and_mask(x, span_length, pct, ceil_pct) for idx, x in enumerate(texts) if idx in idxs]
        raw_fills = replace_masks(masked_texts)
        extracted_fills = extract_fills(raw_fills)
        new_perturbed_texts = apply_extracted_fills(masked_texts, extracted_fills)
        for idx, x in zip(idxs, new_perturbed_texts):
            perturbed_texts[idx] = x
        attempts += 1

    return perturbed_texts


def perturb_texts(texts, span_length, pct, ceil_pct=False):
    chunk_size = chunk_size
    if '11b' in mask_filling_model_name:
        chunk_size //= 2

    outputs = []
    for i in tqdm.tqdm(range(0, len(texts), chunk_size), desc="Applying perturbations"):
        outputs.extend(perturb_texts_(texts[i:i + chunk_size], span_length, pct, ceil_pct=ceil_pct))
    return outputs

In [None]:
## Get perturbation results
import functools

chunk_size = 20
pct_words_masked = 0.3
n_perturbation_rounds = 1

# Get the log likelihood of each text under the base_model
def get_ll(text):
    with torch.no_grad():
        tokenized = base_tokenizer(text, return_tensors="pt").cuda()
        labels = tokenized.input_ids
        return -base_model(**tokenized, labels=labels).loss.item()


def get_lls(texts):
    return [get_ll(text) for text in texts]


def get_perturbation_results(span_length=10, n_perturbations=1, n_samples=500):
    mask_model.cuda()

    torch.manual_seed(0)
    np.random.seed(0)

    results = []
    original_text = df["baseline_completion"]
    sampled_text = df["w_wm_output"]

    perturb_fn = functools.partial(perturb_texts, span_length=span_length, pct=pct_words_masked)

    p_sampled_text = perturb_fn([x for x in sampled_text for _ in range(n_perturbations)])
    p_original_text = perturb_fn([x for x in original_text for _ in range(n_perturbations)])
    for _ in range(n_perturbation_rounds - 1):
        try:
            p_sampled_text, p_original_text = perturb_fn(p_sampled_text), perturb_fn(p_original_text)
        except AssertionError:
            break

    assert len(p_sampled_text) == len(sampled_text) * n_perturbations, f"Expected {len(sampled_text) * n_perturbations} perturbed samples, got {len(p_sampled_text)}"
    assert len(p_original_text) == len(original_text) * n_perturbations, f"Expected {len(original_text) * n_perturbations} perturbed samples, got {len(p_original_text)}"

    for idx in range(len(original_text)):
        results.append({
            "original": original_text[idx],
            "sampled": sampled_text[idx],
            "perturbed_sampled": p_sampled_text[idx * n_perturbations: (idx + 1) * n_perturbations],
            "perturbed_original": p_original_text[idx * n_perturbations: (idx + 1) * n_perturbations]
        })

    mask_model.cpu()

    for res in tqdm.tqdm(results, desc="Computing log likelihoods"):
        p_sampled_ll = get_lls(res["perturbed_sampled"])
        p_original_ll = get_lls(res["perturbed_original"])
        res["original_ll"] = get_ll(res["original"])
        res["sampled_ll"] = get_ll(res["sampled"])
        res["all_perturbed_sampled_ll"] = p_sampled_ll
        res["all_perturbed_original_ll"] = p_original_ll
        res["perturbed_sampled_ll"] = np.mean(p_sampled_ll)
        res["perturbed_original_ll"] = np.mean(p_original_ll)
        res["perturbed_sampled_ll_std"] = np.std(p_sampled_ll) if len(p_sampled_ll) > 1 else 1
        res["perturbed_original_ll_std"] = np.std(p_original_ll) if len(p_original_ll) > 1 else 1

    return results

In [None]:

from sklearn.metrics import roc_curve, precision_recall_curve, auc

pct_words_masked = 0.3

def get_roc_metrics(real_preds, sample_preds):
    fpr, tpr, _ = roc_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds)
    roc_auc = auc(fpr, tpr)
    return fpr.tolist(), tpr.tolist(), float(roc_auc)


def get_precision_recall_metrics(real_preds, sample_preds):
    precision, recall, _ = precision_recall_curve([0] * len(real_preds) + [1] * len(sample_preds), real_preds + sample_preds)
    pr_auc = auc(recall, precision)
    return precision.tolist(), recall.tolist(), float(pr_auc)

def run_perturbation_experiment(results, criterion, span_length=10, n_perturbations=1, n_samples=500):
    # compute diffs with perturbed
    predictions = {'real': [], 'samples': []}
    for res in results:
        if criterion == 'd':
            predictions['real'].append(res['original_ll'] - res['perturbed_original_ll'])
            predictions['samples'].append(res['sampled_ll'] - res['perturbed_sampled_ll'])
        elif criterion == 'z':
            if res['perturbed_original_ll_std'] == 0:
                res['perturbed_original_ll_std'] = 1
                print("WARNING: std of perturbed original is 0, setting to 1")
                print(f"Number of unique perturbed original texts: {len(set(res['perturbed_original']))}")
                print(f"Original text: {res['original']}")
            if res['perturbed_sampled_ll_std'] == 0:
                res['perturbed_sampled_ll_std'] = 1
                print("WARNING: std of perturbed sampled is 0, setting to 1")
                print(f"Number of unique perturbed sampled texts: {len(set(res['perturbed_sampled']))}")
                print(f"Sampled text: {res['sampled']}")
            predictions['real'].append((res['original_ll'] - res['perturbed_original_ll']) / res['perturbed_original_ll_std'])
            predictions['samples'].append((res['sampled_ll'] - res['perturbed_sampled_ll']) / res['perturbed_sampled_ll_std'])

    fpr, tpr, roc_auc = get_roc_metrics(predictions['real'], predictions['samples'])
    p, r, pr_auc = get_precision_recall_metrics(predictions['real'], predictions['samples'])
    name = f'perturbation_{n_perturbations}_{criterion}'
    print(f"{name} ROC AUC: {roc_auc}, PR AUC: {pr_auc}")
    return {
        'name': name,
        'predictions': predictions,
        'info': {
            'pct_words_masked': pct_words_masked,
            'span_length': span_length,
            'n_perturbations': n_perturbations,
            'n_samples': n_samples,
        },
        'raw_results': results,
        'metrics': {
            'roc_auc': roc_auc,
            'fpr': fpr,
            'tpr': tpr,
        },
        'pr_metrics': {
            'pr_auc': pr_auc,
            'precision': p,
            'recall': r,
        },
        'loss': 1 - pr_auc,
    }

In [None]:
## DetectGPT Running: get perturnation results
import json

outputs = []

n_perturbation_list = [1, 10, 100]
span_length = 2
n_samples = 500

for n_perturbations in n_perturbation_list:
    perturbation_results = get_perturbation_results(span_length, n_perturbations, n_samples)
    for perturbation_mode in ['d', 'z']:
        output = run_perturbation_experiment(
            perturbation_results, perturbation_mode, span_length=span_length, n_perturbations=n_perturbations, n_samples=n_samples)
        outputs.append(output)
        with open(os.path.join(OUTPUT_DIR, "detect-gpt", f"perturbation_{n_perturbations}_{perturbation_mode}_results.json"), "w") as f:
            json.dump(output, f)