In [17]:
import csv
import os
import re
import time

import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from scipy.stats import fisher_exact
from tqdm import tqdm

from src.attacks.kgw_detection import test_kgw_detection
from src.attacks.stanford_detection import test_stanford
from src.closed_models import ClosedModel

In [None]:
provider = "google"  # google, openai, claude
# Read the key from the environment rather than hardcoding it in the notebook
# (it would otherwise be committed/shared with the file). Defaults to "" when unset.
api_key = os.environ.get("LLM_API_KEY", "")
model = "gemini-pro"

In [None]:
def generate(prompt, api_key, model, provider, max_tokens, temperature):
    """Query a closed model once and return its completion of ``prompt``.

    Intended as the unit of work for joblib parallel processing: a fresh
    ``ClosedModel`` is built on every call, so no client object is shared
    between calls.
    """
    client = ClosedModel(provider, model, max_tokens, temperature)
    completion = client.generate(api_key, prompt)
    return completion


def identify_fruit(response, word_list):
    """Return the fruit from ``word_list`` that the model chose in ``response``.

    Matching is on whole words, so e.g. "pears" is not detected inside "appears".
    When several listed fruits occur, the one appearing earliest in the response
    wins, rather than whichever comes first in ``word_list`` (which would bias
    counts toward the head of the list).

    Args:
        response: Model output text, or None if the generation call failed.
        word_list: Candidate fruits.

    Returns:
        The matching entry of ``word_list``, or None if ``response`` is None/empty
        or names none of the fruits.
    """
    if not response:
        return None

    best_word, best_pos = None, None
    for word in word_list:
        match = re.search(rf"\b{re.escape(word)}\b", response)
        if match is not None and (best_pos is None or match.start() < best_pos):
            best_word, best_pos = word, match.start()
    return best_word

# Red-Green test

## Phase 1

Query the model a few times and inspect the answers; adjust the parameters below until a good setting is found.

In [None]:
### Test parameters
prefix = "I ate"
word_list = ["mangoes", "pineapples", "papayas", "kiwis"]
example = "apples"
format = ""  # Forces the model output format
context = 15  # context size must upperbound the watermark context size

### Technical parameters
max_tokens = 40
temperature = 1.0
n_queries = 10  # queries per (prefix, k) pair in this exploratory phase

# A single key k = 111...1 (`context` digits) for the chosen prefix.
k_array = [(prefix, int(str(i) * context)) for i in range(1, 2)]

# Answer histogram: fruit -> count (None = response naming no listed fruit).
out = {}

for prefix, k in k_array:
    prompt = f'Complete the sentence "{prefix} {k}" using only and exacty a random word from the list: {word_list}.  Answer in this speific format: {format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you have to choose among {word_list}).'

    for _ in range(n_queries):
        response = generate(prompt, api_key, model, provider, max_tokens, temperature)

        # A None response (failed call) is counted as an unparsed answer instead of crashing.
        fruit = identify_fruit(response, word_list) if response is not None else None
        out[fruit] = out.get(fruit, 0) + 1

        print(response)

# Show the answer distribution for this setting.
print(out)

## Phase 2

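Perform the test: for every (prefix, k) pair, query the model 100 times and append each parsed answer to the results file at `path`.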

In [None]:
### Test parameters
word_list = ["mangoes", "pineapples", "papayas", "kiwis"]
prefixes = [
    "I ate",
    "I chose",
    "I picked",
    "I selected",
    "I took",
    "I went for",
    "I settled on",
    "I got",
    "I gathered",
    "I harvested",
]
example = "apples"
format = ""  # Forces the model output format
context = 15  # context size must upperbound the watermark context size

### Technical parameters
max_tokens = 40
temperature = 1.0
n_queries_per_key = 100  # model queries per (prefix, k) pair
max_retries = 11  # re-queries allowed when a response names no listed fruit
path = "red_green.csv"  # to complete; keep separate from the other tests' files

# One (prefix, k) pair per prefix and per digit d in 1..9, with k = ddd...d (`context` digits).
k_array = [(prefix, int(str(i) * context)) for prefix in prefixes for i in range(1, 10)]


def _parse_fruit(response):
    """Return the listed fruit named in ``response``; None for a None response or no match."""
    if response is None:
        return None
    return identify_fruit(response, word_list)


for prefix, k in tqdm(k_array):
    prompt = f'Complete the sentence "{prefix} {k}" using only and exacty a random word from the list: {word_list}.  Answer in this speific format: {format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you have to choose among {word_list})'

    responses = Parallel(n_jobs=10)(
        delayed(generate)(prompt, api_key, model, provider, max_tokens, temperature)
        for _ in range(n_queries_per_key)
    )

    # Rows go through the csv module so that responses containing ";", quotes or
    # newlines are quoted instead of corrupting the file parsed with pd.read_csv below.
    with open(path, "a", newline="") as f:
        writer = csv.writer(f, delimiter=";", lineterminator="\n")
        for response in responses:
            fruit = _parse_fruit(response)
            print(response)

            # Rejection sampling: re-query until a listed fruit is found or retries run
            # out; in the latter case the last response is stored without a fruit.
            retries = 0
            while fruit is None and retries < max_retries:
                print("Parsing error - Rejection sampling")
                response = generate(
                    prompt, api_key, model, provider, max_tokens, temperature
                )
                fruit = _parse_fruit(response)
                retries += 1

            writer.writerow([prompt, prefix, k, response, fruit])

In [None]:
columns = ["prompt", "prefix", "k", "response", "fruit"]
df = pd.read_csv(path, sep=";", header=None, names=columns)

# One entry per (prefix, k) group, ordered by prefix then k:
# [prefix, k, list of parsed fruits for that group]
nested_list = [
    [prefix_value, k_value, list(fruits)]
    for (prefix_value, k_value), fruits in df.groupby(["prefix", "k"])["fruit"]
]

In [None]:
# Performing the test
# The candidate list must be the one used to generate the data (Red-Green Phase 2);
# answers outside this list are ignored when computing the frequencies.
word_list = ["mangoes", "pineapples", "papayas", "kiwis"]
n_bootstrap = 100
n_resample = 90  # size of each bootstrap resample
rng = np.random.default_rng(0)  # seeded so the bootstrap is reproducible

assert nested_list, "no (prefix, k) groups were loaded"
# Number of distinct k values per prefix (groups are ordered by prefix, then k).
n_k = len({entry[1] for entry in nested_list})
assert len(nested_list) % n_k == 0, "every prefix must have the same set of k values"

probs = np.zeros((n_bootstrap, len(nested_list), len(word_list)))

for i, (_, _, samples) in enumerate(nested_list):
    for j in range(n_bootstrap):
        bootstraped_samples = rng.choice(samples, n_resample, replace=True)
        unique, counts = np.unique(bootstraped_samples, return_counts=True)

        # Order counts according to word_list
        count_by_fruit = dict(zip(unique, counts))
        ordered_counts = np.array([count_by_fruit.get(w, 0) for w in word_list])
        total = ordered_counts.sum()
        # A resample with no listed fruit has no defined distribution: mark it NaN explicitly.
        probs[j, i] = ordered_counts / total if total else np.nan

p_values = []
for bootstrapped_probs in probs:
    # (n_prefixes, n_k, n_words) grid expected by test_kgw_detection
    grid = bootstrapped_probs.reshape(-1, n_k, len(word_list))
    _, _, p = test_kgw_detection(grid, 10000)
    p_values.append(p)

In [None]:
median_p_value = np.median(p_values)  # central p-value across bootstrap replicates
print(median_p_value)

A quick visual check that everything went well.

In [None]:
import matplotlib.pyplot as plt

# Answer frequencies of every (prefix, k) group in the first bootstrap replicate,
# one marker series per candidate fruit (columns of probs[0] follow word_list).
fig, ax = plt.subplots()
for fruit_idx, fruit_name in enumerate(word_list):
    ax.plot(probs[0][:, fruit_idx], "o", label=fruit_name)
ax.set(
    title="Answer frequencies (first bootstrap replicate)",
    xlabel="(prefix, k) group index",
    ylabel="Empirical probability",
)
ax.legend();

# Fixed-sampling test

In [None]:
prompt = "This is the story of"
# Generation settings used by this test, stated explicitly instead of being
# inherited from whichever cell ran last (values match the earlier tests).
max_tokens = 40
temperature = 1.0
path = "fixed_sampling.txt"  # to complete; keep separate from the other tests' files
n_samples = 1000

In [None]:
# Resume support: count the responses already stored in `path` (each one starts a
# line with the ###NEW_RESPONSE### marker) so that only the missing ones are requested.
n_lines = 0
if os.path.exists(path):
    with open(path) as f:
        n_lines = sum(1 for line in f if line.startswith("###NEW_RESPONSE###"))

fresh_responses = Parallel(n_jobs=10)(
    delayed(generate)(prompt, api_key, model, provider, max_tokens, temperature)
    for _ in tqdm(range(n_samples - n_lines))
)

# Append the non-None responses, each prefixed by the marker and newline-terminated.
kept = [response for response in fresh_responses if response is not None]
if kept:
    with open(path, "a") as f:
        for response in kept:
            f.write("###NEW_RESPONSE###" + response + "\n")

In [None]:
# Load every stored response: the text following each ###NEW_RESPONSE### marker.
with open(path) as f:
    whole_txt = f.read()
split = whole_txt.split("###NEW_RESPONSE###")[1:]

# Tally identical responses; dict order is the order of first appearance.
response_counts = {}
for chunk in split:
    response_counts[chunk] = response_counts.get(chunk, 0) + 1

# Regroup the responses so identical ones are adjacent (same multiset as `split`).
data = []
for sentence, count in response_counts.items():
    data.extend([sentence] * count)

test_stanford(data)

# Cache-Augmented test

## Phase 1

In [None]:
prefix = "I ate"
k = "VB@h6C7"
word_list1 = ["pears", "plums"]  # list shown in the first part of the prompt
word_list2 = ["plums", "pears"]  # same fruits, reversed, shown at the end of the prompt
example = "cherries"
format = ""
n_trial = 75

### Technical parameters
max_tokens = 40
temperature = 1.0

prompt = f'Complete the sentence "{prefix} {k}" using only and exacty a random word from the list: {word_list1}.  Answer in this speific format: {format} {prefix} {k} {example}. (here I chose an other fruit for the sake of the example, you must choose among {word_list2})'
# Answer histogram: fruit -> count (None = response naming no listed fruit).
out = {}

responses = Parallel(n_jobs=-1)(
    delayed(generate)(prompt, api_key, model, provider, max_tokens, temperature)
    for _ in tqdm(range(n_trial))
)

for response in responses:
    # A None response (failed call) is counted as an unparsed answer instead of crashing.
    fruit = identify_fruit(response, word_list1) if response is not None else None
    out[fruit] = out.get(fruit, 0) + 1

print(out)

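## Phase 2 (long)

Each query is preceded by a `waiting_time`-second pause so the cache can be cleared; with the default settings (45 queries × 1000 s) this takes about 12.5 hours.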

In [None]:
path = "cache_augmented.txt"  # to complete; keep separate from the other tests' files
waiting_time = 1000  # seconds to wait before each query
n_queries = 45  # number of cache-cleared queries

for _ in tqdm(range(n_queries)):
    time.sleep(waiting_time)  # waiting for the cache to be cleared

    response = generate(prompt, api_key, model, provider, max_tokens, temperature)
    if response is None:
        # Skip a failed call rather than crash on str + None hours into the run.
        print("Empty response - skipped")
        continue

    with open(path, "a") as f:
        f.write("###NEW_RESPONSE###" + response + "\n")

In [None]:
# Read back the Phase 2 responses; each one follows a ###NEW_RESPONSE### marker.
with open(path) as handle:
    whole_txt = handle.read()
split = whole_txt.split("###NEW_RESPONSE###")[1:]

In [None]:
# Count the Phase 2 responses that name the target fruit. The target must be one of
# the candidates offered in the prompt (word_list1); responses are parsed with the
# same helper as in Phase 1 so both phases are counted consistently.
target_fruit = word_list1[0]
k = sum(1 for s in split if identify_fruit(s, word_list1) == target_fruit)

In [None]:
# Fisher exact test: is the target fruit chosen at a different rate after cache
# clearing (Phase 2 row) than in Phase 1? Each row is [target count, other answers];
# both counts are computed here instead of being typed in by hand.
target_fruit = word_list1[0]
phase2_hits = sum(1 for s in split if identify_fruit(s, word_list1) == target_fruit)
phase1_hits = out.get(target_fruit, 0)
phase1_total = sum(out.values())
fisher_exact(
    [
        [phase2_hits, len(split) - phase2_hits],
        [phase1_hits, phase1_total - phase1_hits],
    ]
)