In [1]:
import csv
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
device = 'cuda'
model_name = '../../runs/wikitext/wikitext:1'
gpt2_tokenizer = True
model_precision = "float16"
max_length = 512
input_fn = './data/samples_wikitext:1.csv'
output_fn = f'./scores/scores_wikitext:1.csv'

In [3]:
if gpt2_tokenizer:
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
else:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
if model_precision == "float16":
    model = AutoModelForCausalLM.from_pretrained(model_name, revision="float16", torch_dtype=torch.float16,
                                                 return_dict=True).to(device)
else:
    model = AutoModelForCausalLM.from_pretrained(model_name, return_dict=True).to(device)

2023-11-10 07:04:32.372810: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [5]:
df = pd.read_csv(input_fn)
df.head(1)

Unnamed: 0,group,watermark,used?,bits
0,0,= Valkyria Chronicles III = \n Senjō no Valky...,True,36


In [6]:
out_fh = open(output_fn, 'wt')
out = csv.writer(out_fh)

In [7]:
for i, row in tqdm(df.iterrows(), total=len(df)):
    group, wm, used = row['group'], row['watermark'], row['used?']
    input_ids = tokenizer.encode(wm, \
                                 return_tensors='pt', \
                                 max_length=None, \
                                 padding=False).to(device)
    
    input_ids = input_ids[:,-max_length:]
    
    with torch.no_grad():
        model.eval()
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss
        logits = outputs.logits

    if i % 100 == 0:
        print(wm[:100], loss.item())

    out.writerow([group, wm, used, loss.item()])

  0%|          | 0/3822 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (4468 > 1024). Running this sequence through the model will result in indexing errors


 = Valkyria Chronicles III = 
 Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3  5.82421875
 = There 's Got to Be a Way = 
 " There 's Got to Be a Way " is a song by American singer and songwr 4.75390625
 = South of Heaven = 
 South of Heaven is the fourth studio album by American thrash metal band Slay 4.828125
 = The Feast of the Goat = 
 The Feast of the Goat ( Spanish : La fiesta del chivo , 2000 ) is a nov 5.4375
 = Zagreb Synagogue = 
 The Zagreb Synagogue ( Croatian : Zagrebačka sinagoga ) was the main place o 4.94921875
 = Michael Jordan = 
 Michael Jeffrey Jordan ( born February 17 , 1963 ) , also known by his initial 4.8046875
 = Rhode Island Route 4 = 
 Route 4 , also known as the Colonel Rodman Highway , is a 10 @.@ 37 @-@  4.6640625
 = Hadji Ali = 
 Hadji Ali ( c . 1887 – 92 – November 5 , 1937 ) was a vaudeville performance artist 5.03515625
 = Astraeus hygrometricus = 
 Astraeus hygrometricus , commonly known as the hygroscopic earthstar , 5.25390625


In [8]:
out_fh.close()

In [9]:
df = pd.read_csv(output_fn, header=None)
df.columns = ['group', 'watermark', 'used?', 'loss']
df.head(1)

Unnamed: 0,group,watermark,used?,loss
0,0,= Valkyria Chronicles III = \n Senjō no Valky...,True,5.824219


In [10]:
for i, g in df.groupby('group'):
    test_statistic = g.iloc[0]['loss']
    samples = g.iloc[1:]
    p = np.mean(samples.loss > test_statistic)
    print(i, p, len(samples), test_statistic, len(g.iloc[0]['watermark']))

0 0.8 20 5.82421875 20908
1 1.0 20 4.58984375 21487
2 0.0 20 5.47265625 16161
5 0.95 20 4.625 17813
14 0.95 20 4.69921875 4915
16 0.25 20 4.40234375 14842
17 0.6 20 5.05078125 14228
18 0.95 20 4.97265625 12193
20 0.75 20 5.37890625 55982
21 1.0 20 4.72265625 13192
22 0.9 20 5.79296875 37173
25 1.0 20 4.88671875 25354
26 0.6 20 5.859375 11802
27 0.95 20 5.29296875 14778
28 0.2 20 5.4296875 30025
29 1.0 20 5.32421875 20621
30 0.6 20 4.671875 29044
32 0.5 20 5.234375 34415
33 0.85 20 4.9453125 29005
34 0.0 20 4.94921875 12853
35 0.8 20 4.96484375 6551
36 1.0 20 5.3828125 12323
37 0.8 20 5.00390625 7567
38 0.9 20 4.765625 51477
39 0.75 20 5.44140625 47831
40 0.95 20 5.453125 5656
41 1.0 20 4.62890625 23590
42 0.65 20 5.04296875 39665
44 0.95 20 4.57421875 10954
46 0.65 20 3.951171875 28822
47 0.75 20 4.421875 11030
48 0.7 20 4.68359375 32988
50 0.55 20 5.2109375 25080
52 0.8 20 5.03125 12277
54 0.9 20 5.0625 27096
55 0.45 20 4.6640625 18974
56 0.0 20 5.48828125 25406
57 1.0 20 5.14453125 7