In [1]:
import csv
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import datasets

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Default run configuration for the scoring notebook.
device = "cuda"                                # torch device the model and inputs are moved to
model_name = "../../runs/wikitext/wikitext:1"  # checkpoint location handed to from_pretrained
gpt2_tokenizer = True                          # True -> use the stock 'gpt2' tokenizer
model_precision = "float16"                    # "float16" selects the half-precision load path
max_length = 512                               # number of trailing tokens kept per text
input_fn = "./data/samples_wikitext:1.csv"     # CSV of texts to score
output_fn = "./scores/scores_wikitext:1.csv"   # CSV that receives one scored row per text

In [3]:
# Parameters
# These assignments rebind model_name, input_fn and output_fn, replacing the
# defaults given to the same names in the configuration cell for this run.
# NOTE(review): the paths are absolute and machine-specific; presumably this
# cell is injected by a notebook-parameterisation tool — confirm before editing by hand.
model_name = "/home/johnny/gpt-neox/haveibeentrainedon/acl2024/strong_subs/data/frac:1/70M"
input_fn = "/home/johnny/gpt-neox/haveibeentrainedon/acl2024/strong_subs/data/frac:1/samples.csv"
output_fn = "/home/johnny/gpt-neox/haveibeentrainedon/acl2024/strong_subs/scores/scores:1.csv"


In [4]:
# Pick where the tokenizer comes from: the stock 'gpt2' vocabulary when
# gpt2_tokenizer is truthy, otherwise whatever ships with the model checkpoint.
tokenizer_source = 'gpt2' if gpt2_tokenizer else model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

In [5]:
# Assemble the from_pretrained keyword arguments once; the half-precision
# settings are added only when model_precision is exactly "float16".
load_kwargs = {"return_dict": True}
if model_precision == "float16":
    load_kwargs.update(revision="float16", torch_dtype=torch.float16)

# Load the causal LM and move it to the configured device.
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
model = model.to(device)

In [6]:
# Read the samples to score from input_fn and preview the first row.
df = pd.read_csv(filepath_or_buffer=input_fn)
df.head(n=1)

Unnamed: 0,group,watermark,used?
0,0,= Valkyria Chronicles III = \n Senjō no Valky...,True


In [7]:
# Open the scores file for writing and wrap it in a CSV writer.
# newline='' is what the csv module requires of file objects it writes to:
# the writer emits its own line terminators, and the `watermark` field contains
# embedded newlines that must not be translated by the text layer.
# The encoding is pinned to UTF-8 because the texts are non-ASCII and the file
# is later read back with pd.read_csv, whose default encoding is UTF-8.
out_fh = open(output_fn, 'wt', newline='', encoding='utf-8')
out = csv.writer(out_fh)

In [8]:
# Score every sample with the model's language-modelling loss and append one
# CSV row per sample: [group, watermark, used?, loss].

# Switching to eval mode is loop-invariant, so do it once before scoring.
model.eval()

for i, row in tqdm(df.iterrows(), total=len(df)):
    group, wm, used = row['group'], row['watermark'], row['used?']

    # Tokenize the whole text (no truncation, no padding) and move it to the device.
    input_ids = tokenizer.encode(wm,
                                 return_tensors='pt',
                                 max_length=None,
                                 padding=False).to(device)

    # Keep only the last `max_length` tokens of the sequence.
    input_ids = input_ids[:, -max_length:]

    # Passing the inputs as labels makes the forward pass return a loss;
    # no gradients are needed for scoring.
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss
    loss_value = loss.item()

    # Progress sample: show the start of the text and its loss for rows whose
    # index label is a multiple of 100.
    if i % 100 == 0:
        print(wm[:100], loss_value)

    out.writerow([group, wm, used, loss_value])

  0%|          | 0/4053 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (4468 > 1024). Running this sequence through the model will result in indexing errors


 = Valkyria Chronicles III = 
 Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3  5.8203125


 = 2011 – 12 Columbus Blue Jackets season = 
 The 2011 – 12 Columbus Blue Jackets season was the tea 4.6484375


 = Ancient Egyptian deities = 
 Ancient Egyptian deities are the gods and goddesses worshipped in an 5.40234375


 = Treaty of Ciudad Juárez = 
 The Treaty of Ciudad Juárez was a peace treaty signed between the the 5.31640625


 = Devin Townsend = 
 Devin Garret Townsend ( born May 5 , 1972 ) is a Canadian musician , songwrite 4.9453125


 = Michael Jordan = 
 Michael Jeffrey Jordan ( born February 17 , 1963 ) , also known by his initial 4.80078125


 = Wrapped in Red = 
 Wrapped in Red is the sixth studio album by American recording artist Kelly Cl 3.955078125


 = Hadji Ali = 
 Hadji Ali ( c . 1887 – 92 – November 5 , 1937 ) was a vaudeville performance artist 5.02734375


 = Mortimer Wheeler = 
 Sir Robert Eric Mortimer Wheeler CH , CIE , MC , TD , FSA , FRS , FBA ( 10 S 5.1640625


 = The Fox , the Wolf and the Husbandman = 
 The Fox , the Wolf and the Husbandman is a poem by the  5.4765625


 = Livin ' the Dream = 
 " Livin ' the Dream " is the twenty @-@ initial episode of the ninth season 4.58984375


 = Shaoguan incident = 
 The Shaoguan incident was a civil disturbance which took place overnight on 5.625


 = Yamaha NS @-@ 10 = 
 The Yamaha NS @-@ 10 is a loudspeaker that became a standard nearfield studi 5.4609375


 = Of Human Feelings = 
 Of Human Feelings is a studio album by American jazz saxophonist and compos 4.67578125


 = World War Z = 
 World War Z : An Oral History of the Zombie War ( 2006 ) is an apocalyptic horror 4.84375


 = Frederick Reines = 
 Frederick Reines ( RYE @-@ ness ) ; ( March 16 , 1918 – August 26 , 1998 ) w 4.9921875


 = Lloyd Mathews = 
 Sir Lloyd William Mathews , GCMG , CB ( 7 March 1850 – 11 October 1901 ) was a  5.046875


 = Lost Horizons ( Lemon Jelly album ) = 
 Lost Horizons is the second studio album from the British 4.14453125


 = Hellblazer = 
 Hellblazer ( also known as John Constantine , Hellblazer ) is an American contempo 5.0546875


 = Mega Man & Bass = 
 Mega Man & Bass , known in Japan as Rockman & Forte ( ロックマン & フォルテ , Rokkuman 5.05859375


 = James Nesbitt = 
 William James Nesbitt , OBE ( born 15 January 1965 ) is an actor and presenter  4.8515625


 = Leslie Andrew = 
 Brigadier Leslie Wilton Andrew VC DSO ( 23 March 1897 – 8 January 1969 ) was a  4.453125


 = HMS Marlborough ( 1912 ) = 
 HMS Marlborough was an Iron Duke @-@ class battleship of the British 4.23828125


 = Frank Slide = 
 The Frank Slide was a rockslide that buried part of the mining town of Frank , Al 5.0625


 = You 're Gonna Love Tomorrow = 
 " You 're Gonna Love Tomorrow " is the fifth season premiere epis 4.65234375


 = The Food Album = 
 The Food Album is a compilation album by American singer @-@ songwriter " Weir 4.18359375


 = 1981 Peach Bowl ( January ) = 
 The 1981 Peach Bowl was a post @-@ season American college footba 4.71484375


 = Action of 13 September 1810 = 
 The Action of 13 September 1810 was an inconclusive frigate engag 5.40625


 = Fort Glanville Conservation Park = 
 Fort Glanville Conservation Park is a protected area located 5.2734375


 = Ceratopsia = 
 Ceratopsia or Ceratopia ( / ˌsɛrəˈtɒpsiə / or / ˌsɛrəˈtoʊpiə / ; Greek : " horned  4.56640625


 = Guitar Hero = 
 The Guitar Hero series ( sometimes referred to as the Hero series ) is a series o 4.94140625


 = Tintin in the Congo = 
 Tintin in the Congo ( French : Tintin au Congo ; French pronunciation : ​ 5.25


 = Gertrude Barrows Bennett = 
 Gertrude Barrows Bennett ( 1883 – 1948 ) was the initial major femal 5.13671875


 = St Peulan 's Church , Llanbeulan = 
 St Peulan 's Church , Llanbeulan is a disused medieval churc 4.80859375


 = No result , Pts = 
 Points , NRR = Net run rate . 
 Notes : 
 Teams marked * progressed to the fo 5.3828125


 = Don 't You Wanna Stay = 
 " Don 't You Wanna Stay " is a duet recorded by American singers Jason  4.51953125


 = Dover Athletic F.C. = 
 Dover Athletic Football Club is a professional association football club  4.44140625


 = Battle of Hubbardton = 
 The Battle of Hubbardton was an engagement in the Saratoga campaign of t 4.9609375


 = First Light ( Rebecca Stead novel ) = 
 First Light is a young adult science fiction and mystery  5.50390625


 = Something Borrowed ( Torchwood ) = 
 " Something Borrowed " is the ninth episode of the second se 4.859375


 = George N. Briggs = 
 George Nixon Briggs ( April 12 , 1796 – September 12 , 1861 ) was an America 4.8984375


In [9]:
# Close the scores file so every row written through `out` is flushed to disk
# before the file is read back.
out_fh.close()

In [10]:
# Reload the scores just written. The file has no header row, so the column
# labels are attached after parsing. Note that `df` is rebound here from the
# input samples to the scored rows.
score_columns = ['group', 'watermark', 'used?', 'loss']
df = pd.read_csv(output_fn, header=None).set_axis(score_columns, axis='columns')
df.head(n=1)

Unnamed: 0,group,watermark,used?,loss
0,0,= Valkyria Chronicles III = \n Senjō no Valky...,True,5.820312


In [11]:
# Per group: take the loss of the first row as the test statistic and report
# the fraction of the remaining rows whose loss is strictly greater than it.
# Printed fields: group key, that fraction, number of remaining rows,
# the test statistic, and the character length of the first row's text.
for group_id, group_rows in df.groupby('group'):
    first_row = group_rows.iloc[0]
    test_statistic = first_row['loss']
    samples = group_rows.iloc[1:]
    p = samples['loss'].gt(test_statistic).mean()
    print(group_id, p, len(samples), test_statistic, len(first_row['watermark']))

0 0.9 20 5.8203125 20908
1 1.0 20 4.58203125 21487
2 0.0 20 5.46484375 16161
3 0.95 20 4.6796875 3667
5 0.95 20 4.62109375 17813
14 1.0 20 4.69921875 4915
16 0.3 20 4.40234375 14842
17 0.9 20 5.05078125 14228
18 0.95 20 4.96875 12193
20 0.8 20 5.37890625 55982
21 1.0 20 4.72265625 13192
22 0.8 20 5.78515625 37173
25 1.0 20 4.88671875 25354
26 0.85 20 5.85546875 11802
27 0.7 20 5.29296875 14778
28 0.55 20 5.4296875 30025
29 1.0 20 5.32421875 20621
30 0.6 20 4.66796875 29044
32 0.6 20 5.23828125 34415
33 0.85 20 4.9453125 29005
34 0.0 20 4.9453125 12853
35 0.8 20 4.96484375 6551
36 0.9 20 5.37890625 12323
38 0.95 20 4.76953125 51477
39 0.5 20 5.44140625 47831
41 0.95 20 4.62109375 23590
42 0.6 20 5.04296875 39665
45 0.3 20 3.607421875 11512
46 0.45 20 3.94921875 28822
47 0.85 20 4.42578125 11030
48 0.65 20 4.6875 32988
49 0.65 20 4.53515625 10653
50 0.4 20 5.20703125 25080
52 0.5 20 5.02734375 12277
53 1.0 20 4.828125 7314
54 0.9 20 5.05859375 27096
55 0.65 20 4.66015625 18974
56 0.0 20 