-
Notifications
You must be signed in to change notification settings - Fork 25.2k
/
watermarking.py
240 lines (200 loc) · 10.7 KB
/
watermarking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import math
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, Optional, Union

import numpy as np

from ..configuration_utils import PretrainedConfig
from ..utils import is_torch_available, logging
from .configuration_utils import WatermarkingConfig


if is_torch_available():
    import torch

    from .logits_process import WatermarkLogitsProcessor
logger = logging.get_logger(__name__)
@dataclass
class WatermarkDetectorOutput:
    """
    Outputs of a watermark detector.

    Every field defaults to `None`, so an instance can be created empty and filled in afterwards.

    Args:
        num_tokens_scored (`np.ndarray` of shape `(batch_size,)`):
            Array containing the number of tokens scored for each element in the batch.
        num_green_tokens (`np.ndarray` of shape `(batch_size,)`):
            Array containing the number of green tokens for each element in the batch.
        green_fraction (`np.ndarray` of shape `(batch_size,)`):
            Array containing the fraction of green tokens for each element in the batch.
        z_score (`np.ndarray` of shape `(batch_size,)`):
            Array containing the z-score for each element in the batch. The z-score shows how many standard
            deviations the green token count of the input text lies away from the count expected when each
            scored token is green with probability `greenlist_ratio`.
        p_value (`np.ndarray` of shape `(batch_size,)`):
            Array containing the p-value obtained from the z-score for each element in the batch.
        prediction (`np.ndarray` of shape `(batch_size,)`, *optional*):
            Array containing boolean predictions whether a text is machine-generated for each element in the batch.
        confidence (`np.ndarray` of shape `(batch_size,)`, *optional*):
            Array containing confidence scores of a text being machine-generated for each element in the batch.
    """

    # `np.ndarray` is the array type; `np.array` is only the factory function and is not a valid annotation.
    num_tokens_scored: Optional[np.ndarray] = None
    num_green_tokens: Optional[np.ndarray] = None
    green_fraction: Optional[np.ndarray] = None
    z_score: Optional[np.ndarray] = None
    p_value: Optional[np.ndarray] = None
    prediction: Optional[np.ndarray] = None
    confidence: Optional[np.ndarray] = None
class WatermarkDetector:
    r"""
    Detector for detection of watermark generated text. The detector needs to be given the exact same settings that were
    given during text generation to replicate the watermark greenlist generation and so detect the watermark. This includes
    the correct device that was used during text generation, the correct watermarking arguments and the correct tokenizer vocab size.
    The code was based on the [original repo](https://github.com/jwkirchenbauer/lm-watermarking/tree/main).

    See [the paper](https://arxiv.org/abs/2306.04634) for more information.

    Args:
        model_config (`PretrainedConfig`):
            The model config that will be used to get model specific arguments used when generating.
        device (`str`):
            The device which was used during watermarked text generation.
        watermarking_config (Union[`WatermarkingConfig`, `Dict`]):
            The exact same watermarking config and arguments used when generating text.
        ignore_repeated_ngrams (`bool`, *optional*, defaults to `False`):
            Whether to count every unique ngram only once or not.
        max_cache_size (`int`, *optional*, defaults to 128):
            The max size to be used for LRU caching of seeding/sampling algorithms called for every token.

    Examples:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig

    >>> model_id = "openai-community/gpt2"
    >>> model = AutoModelForCausalLM.from_pretrained(model_id)
    >>> tok = AutoTokenizer.from_pretrained(model_id)
    >>> tok.pad_token_id = tok.eos_token_id
    >>> tok.padding_side = "left"

    >>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
    >>> input_len = inputs["input_ids"].shape[-1]

    >>> # first generate text with watermark and without
    >>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
    >>> out_watermarked = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)
    >>> out = model.generate(**inputs, do_sample=False, max_length=20)

    >>> # now we can instantiate the detector and check the generated text
    >>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
    >>> detection_out_watermarked = detector(out_watermarked, return_dict=True)
    >>> detection_out = detector(out, return_dict=True)
    >>> detection_out_watermarked.prediction
    array([ True,  True])

    >>> detection_out.prediction
    array([False, False])
    ```
    """

    def __init__(
        self,
        model_config: "PretrainedConfig",
        device: str,
        watermarking_config: "Union[WatermarkingConfig, Dict]",
        ignore_repeated_ngrams: bool = False,
        max_cache_size: int = 128,
    ):
        if isinstance(watermarking_config, WatermarkingConfig):
            watermarking_config = watermarking_config.to_dict()

        # Encoder-decoder models start decoding from `decoder_start_token_id` rather than `bos_token_id`.
        self.bos_token_id = (
            model_config.bos_token_id if not model_config.is_encoder_decoder else model_config.decoder_start_token_id
        )
        self.greenlist_ratio = watermarking_config["greenlist_ratio"]
        self.ignore_repeated_ngrams = ignore_repeated_ngrams
        self.processor = WatermarkLogitsProcessor(
            vocab_size=model_config.vocab_size, device=device, **watermarking_config
        )

        # Expensive re-seeding and sampling is cached. The cache is per instance and keyed on hashable
        # `(prefix tuple, target, device)` arguments, see `_score_ngrams_in_passage`.
        self._get_ngram_score_cached = lru_cache(maxsize=max_cache_size)(self._get_ngram_score)

    def _get_ngram_score(self, prefix, target: int, device=None) -> bool:
        """
        Tell whether `target` belongs to the green list seeded by `prefix`.

        Args:
            prefix (`torch.LongTensor` or sequence of `int`):
                Context token ids used to seed the green list. A plain sequence (e.g. a tuple, which is what
                the cached variant receives) is converted to a `torch.long` tensor on `device`.
            target (`int`):
                The token id to look up in the green list.
            device (`torch.device` or `str`, *optional*):
                Device on which to build the prefix tensor when `prefix` is not already a tensor.

        Return:
            `bool`: `True` if `target` is in the green list.
        """
        if not isinstance(prefix, torch.Tensor):
            prefix = torch.as_tensor(prefix, dtype=torch.long, device=device)
        greenlist_ids = self.processor._get_greenlist_ids(prefix)
        return target in greenlist_ids

    def _score_ngrams_in_passage(self, input_ids: torch.LongTensor):
        """
        Count the scored and the green ngrams of every sequence in the batch.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Token ids to score.

        Return:
            `Tuple[np.ndarray, np.ndarray]`: number of scored tokens and number of green tokens, each of shape
            `(batch_size,)`.
        """
        batch_size, seq_length = input_ids.shape
        # With "selfhash" the target token is part of its own seeding context, so ngrams are one token shorter.
        selfhash = int(self.processor.seeding_scheme == "selfhash")
        n = self.processor.context_width + 1 - selfhash
        # Sliding-window indices: row `i` selects tokens `i .. i + n - 1`.
        indices = torch.arange(n).unsqueeze(0) + torch.arange(seq_length - n + 1).unsqueeze(1)
        ngram_tensors = input_ids[:, indices]
        device = input_ids.device

        num_tokens_scored_batch = np.zeros(batch_size)
        green_token_count_batch = np.zeros(batch_size)
        # Tensors hash by identity, so equal ngrams held as tensors would never be merged by `Counter` nor hit
        # the LRU cache. Plain tuples of ints hash by value.
        for batch_idx, ngrams in enumerate(ngram_tensors.tolist()):
            frequencies_table = collections.Counter(tuple(ngram) for ngram in ngrams)
            ngram_to_watermark_lookup = {}
            for ngram_example in frequencies_table:
                prefix = ngram_example if selfhash else ngram_example[:-1]
                target = ngram_example[-1]
                ngram_to_watermark_lookup[ngram_example] = self._get_ngram_score_cached(prefix, target, device)

            if self.ignore_repeated_ngrams:
                # counts a green/red hit once per unique ngram.
                # num total tokens scored becomes the number unique ngrams.
                num_tokens_scored_batch[batch_idx] = len(frequencies_table)
                green_token_count_batch[batch_idx] = sum(ngram_to_watermark_lookup.values())
            else:
                num_tokens_scored_batch[batch_idx] = sum(frequencies_table.values())
                green_token_count_batch[batch_idx] = sum(
                    freq * ngram_to_watermark_lookup[ngram] for ngram, freq in frequencies_table.items()
                )
        return num_tokens_scored_batch, green_token_count_batch

    def _compute_z_score(self, green_token_count: np.ndarray, total_num_tokens: np.ndarray) -> np.ndarray:
        """
        Compute the z-score of the green token count, treating every scored token as green with probability
        `self.greenlist_ratio`.

        Args:
            green_token_count (`np.ndarray`): Number of green tokens per batch element.
            total_num_tokens (`np.ndarray`): Number of scored tokens per batch element.

        Return:
            `np.ndarray`: `(green - ratio * total) / sqrt(total * ratio * (1 - ratio))` per batch element.
        """
        green_probability = self.greenlist_ratio
        numer = green_token_count - green_probability * total_num_tokens
        denom = np.sqrt(total_num_tokens * green_probability * (1 - green_probability))
        return numer / denom

    def _compute_pval(self, x, loc=0, scale=1):
        """
        Compute the one-sided (upper tail) p-value of `x` under a normal distribution.

        Args:
            x (`np.ndarray` or `float`): Observed value(s), here the z-scores.
            loc (`float`, *optional*, defaults to 0): Mean of the normal distribution.
            scale (`float`, *optional*, defaults to 1): Standard deviation of the normal distribution.

        Return:
            `np.ndarray` or `np.float64`: `P(X > x)`, with the same shape as `x`.
        """
        z = (np.asarray(x, dtype=float) - loc) / scale
        # Exact normal survival function: P(Z > z) = erfc(z / sqrt(2)) / 2. A closed-form approximation
        # is not accurate enough in the tails, which is exactly where detection p-values live.
        upper_tail = 0.5 * np.vectorize(math.erfc, otypes=[float])(z / math.sqrt(2.0))
        # `[()]` turns a 0-d result (scalar input) into a numpy scalar and leaves real arrays untouched.
        return upper_tail[()]

    def __call__(
        self,
        input_ids: torch.LongTensor,
        z_threshold: float = 3.0,
        return_dict: bool = False,
    ) -> "Union[WatermarkDetectorOutput, np.ndarray]":
        """
        Args:
            input_ids (`torch.LongTensor`):
                The watermark generated text. It is advised to remove the prompt, which can affect the detection.
            z_threshold (`float`, *optional*, defaults to `3.0`):
                Changing this threshold will change the sensitivity of the detector. Higher z threshold gives less
                sensitivity and vice versa for lower z threshold.
            return_dict (`bool`, *optional*, defaults to `False`):
                Whether to return `~generation.WatermarkDetectorOutput` or not. If not it will return boolean
                predictions, one per batch element, that are `True` when the z-score exceeds `z_threshold`.

        Return:
            [`~generation.WatermarkDetectorOutput`] or `np.ndarray`: A [`~generation.WatermarkDetectorOutput`]
            if `return_dict=True` otherwise a `np.ndarray`.

        Raises:
            `ValueError`: If, after removing a leading BOS token, there is no token left to score beyond the
            context required by the seeding scheme.
        """
        # Let's assume that if one batch start with `bos`, all batched also do.
        # The `numel` guard lets empty inputs reach the explicit length check below instead of failing on indexing.
        if input_ids.numel() > 0 and input_ids[0, 0] == self.bos_token_id:
            input_ids = input_ids[:, 1:]

        if input_ids.shape[-1] - self.processor.context_width < 1:
            raise ValueError(
                f"Must have at least `1` token to score after the first "
                f"min_prefix_len={self.processor.context_width} tokens required by the seeding scheme."
            )

        num_tokens_scored, green_token_count = self._score_ngrams_in_passage(input_ids)
        z_score = self._compute_z_score(green_token_count, num_tokens_scored)
        prediction = z_score > z_threshold

        if return_dict:
            p_value = self._compute_pval(z_score)
            confidence = 1 - p_value

            return WatermarkDetectorOutput(
                num_tokens_scored=num_tokens_scored,
                num_green_tokens=green_token_count,
                green_fraction=green_token_count / num_tokens_scored,
                z_score=z_score,
                p_value=p_value,
                prediction=prediction,
                confidence=confidence,
            )
        return prediction