Add Watermarking LogitsProcessor and WatermarkDetector #29676

Merged: 47 commits, merged on May 14, 2024. The diff below shows the changes from 2 of those commits.

Commits
9af1d2c
add watermarking processor
zucchini-nlp Mar 15, 2024
92a5214
remove the other hashing (context width=1 always)
zucchini-nlp Mar 15, 2024
3b2c6da
make style
zucchini-nlp Mar 15, 2024
bd1a8aa
Update src/transformers/generation/logits_process.py
zucchini-nlp Mar 15, 2024
3756540
Update src/transformers/generation/logits_process.py
zucchini-nlp Mar 15, 2024
c67069b
Update src/transformers/generation/logits_process.py
zucchini-nlp Mar 15, 2024
52b58cf
Update src/transformers/generation/configuration_utils.py
zucchini-nlp Mar 15, 2024
e4c92b8
update watermarking process
zucchini-nlp Mar 18, 2024
e77ea5e
add detector
zucchini-nlp Mar 18, 2024
ab8f79f
update tests to use detector
zucchini-nlp Mar 18, 2024
bd4b875
fix failing tests
zucchini-nlp Mar 19, 2024
9bf52b3
Merge remote-tracking branch 'upstream/main' into watermark
zucchini-nlp Mar 19, 2024
e2e689b
rename `input_seq`
zucchini-nlp Mar 19, 2024
6dd8eb3
make style
zucchini-nlp Mar 19, 2024
5ba45c0
doc for processor
zucchini-nlp Mar 19, 2024
f9c6594
minor fixes
zucchini-nlp Mar 19, 2024
0597a17
docs
zucchini-nlp Mar 21, 2024
d4f5de1
Merge remote-tracking branch 'upstream/main' into watermark
zucchini-nlp Mar 21, 2024
77d8745
make quality
zucchini-nlp Mar 21, 2024
8cc4453
Merge remote-tracking branch 'upstream/main' into watermark
zucchini-nlp Mar 22, 2024
82f0853
Update src/transformers/generation/configuration_utils.py
zucchini-nlp Mar 26, 2024
1216142
Update src/transformers/generation/logits_process.py
zucchini-nlp Mar 26, 2024
5e671e0
Update src/transformers/generation/watermarking.py
zucchini-nlp Mar 26, 2024
f50e945
Update src/transformers/generation/watermarking.py
zucchini-nlp Mar 26, 2024
c1c9ed8
Update src/transformers/generation/watermarking.py
zucchini-nlp Mar 26, 2024
2055f56
Merge remote-tracking branch 'upstream/main' into watermark
zucchini-nlp Mar 26, 2024
3578150
add PR suggestions
zucchini-nlp Mar 27, 2024
b477eb5
let's use lru_cache's default max size (128)
zucchini-nlp Mar 27, 2024
cab4969
import processor if torch available
zucchini-nlp Mar 27, 2024
c03e752
maybe like this
zucchini-nlp Mar 27, 2024
b28f646
lets move the config to torch independet file
zucchini-nlp Mar 27, 2024
966808d
add docs
zucchini-nlp Mar 27, 2024
2d0c3e3
tiny docs fix to make the test happy
zucchini-nlp Apr 1, 2024
8223376
Update src/transformers/generation/configuration_utils.py
zucchini-nlp Apr 3, 2024
f33a3a2
Update src/transformers/generation/watermarking.py
zucchini-nlp Apr 3, 2024
6e60d32
PR suggestions
zucchini-nlp Apr 3, 2024
7ae9ae9
add docs
zucchini-nlp Apr 3, 2024
863663c
fix test
zucchini-nlp Apr 5, 2024
76a66b5
Merge remote-tracking branch 'upstream/main' into watermark
zucchini-nlp Apr 5, 2024
177c765
fix docs
zucchini-nlp Apr 5, 2024
e6da307
Merge branch 'huggingface:main' into watermark
zucchini-nlp Apr 18, 2024
f036d49
address pr comments
zucchini-nlp May 8, 2024
1c3e987
Merge branch 'main' into watermark
zucchini-nlp May 9, 2024
7f33cc3
style
zucchini-nlp May 9, 2024
5e70bab
Revert "style"
zucchini-nlp May 9, 2024
3be20e3
correct style
zucchini-nlp May 9, 2024
f15935c
make doctest green
zucchini-nlp May 9, 2024
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1410,6 +1410,7 @@
"TopPLogitsWarper",
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WatermarkLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
]
)
@@ -6207,6 +6208,7 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .modeling_utils import PreTrainedModel
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
@@ -77,6 +77,7 @@
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
"WatermarkLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
@@ -213,6 +214,7 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
19 changes: 18 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -224,7 +224,19 @@ class GenerationConfig(PushToHubMixin):
low_memory (`bool`, *optional*):
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
Used with beam search and contrastive search.

watermark (`bool`, *optional*):
Watermark the model outputs by adding a small bias to a randomly selected set of "green" tokens.
greenlist_ratio (`float`, *optional*):
Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
watermark_bias (`float`, *optional*):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
hashing_key (`int`, *optional*):
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
seeding_scheme (`str`, *optional*):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" token selection depends on the last token (Algorithm 2 from the paper)
- "selfhash": "green" token selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".

> Parameters that define the output variables of `generate`

@@ -340,6 +352,11 @@ def __init__(self, **kwargs):
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.guidance_scale = kwargs.pop("guidance_scale", None)
self.low_memory = kwargs.pop("low_memory", None)
self.watermark = kwargs.pop("watermark", False)
self.greenlist_ratio = kwargs.pop("greenlist_ratio", 0.25)
self.watermark_bias = kwargs.pop("watermark_bias", 2.0)
self.hashing_key = kwargs.pop("hashing_key", 15485863)
self.seeding_scheme = kwargs.pop("seeding_scheme", "lefthash")

# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
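
The new kwargs above can be passed straight to `generate` or grouped into a `GenerationConfig`. A minimal sketch of a watermarked call, assuming the same `openai-community/gpt2` checkpoint used in the integration test further down (the values simply restate the defaults introduced in this hunk):

```python
# Sketch only: mirrors the kwargs added in this hunk; the checkpoint is an assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")

generation_config = GenerationConfig(
    watermark=True,            # enables WatermarkLogitsProcessor during generation
    greenlist_ratio=0.25,      # fraction of the vocabulary biased as "green"
    watermark_bias=2.0,        # logit bias added to green tokens
    hashing_key=15485863,      # private hashing key (the millionth prime by default)
    seeding_scheme="lefthash",
    max_new_tokens=20,
    do_sample=False,
)

inputs = tokenizer("This is the beginning of a long story", return_tensors="pt")
# The tokenizer must be passed to `generate` so the processor can read its vocab size.
out = model.generate(**inputs, generation_config=generation_config, tokenizer=tokenizer)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```
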
123 changes: 123 additions & 0 deletions src/transformers/generation/logits_process.py
@@ -20,6 +20,7 @@

import numpy as np
import torch
import torch.nn.functional as F

from ..utils import add_start_docstrings
from ..utils.logging import get_logger
@@ -2215,3 +2216,125 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores = torch.where(do_early_stop, early_stop_scores, scores)

return scores


class WatermarkLogitsProcessor(LogitsProcessor):
r"""
Logits processor for watermarking generated text. The processor modifies model output scores by adding a small bias to
a randomized set of "green" tokens before generating the next token. The "green" token selection process depends on the
`seeding_scheme` used.

See [the paper](https://arxiv.org/abs/2301.10226) for more information.

Args:
vocab_size (`int`):
The model tokenizer's vocab_size. Used to calculate the "green" token ratio.
device (`str`):
The device where model is allocated.
greenlist_ratio (`float`, optional):
The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
bias (`float`, optional):
The bias added to the selected "green" tokens' logits. Consider lowering the
`bias` if the text generation quality degrades. Recommended values are in the
range of [0.5, 2.0]. Defaults to 2.0.
hashing_key (`int`, optional):
Key used for hashing. If you deploy this watermark, we advise using another private key.
Defaults to 15485863 (the millionth prime).
seeding_scheme (`str`, optional):
The seeding scheme used for selecting "green" tokens. Accepts values:
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from paper)
- "selfhash": "green" tokens selection depends ono the current token itself (Algorithm 3 from paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".

Examples:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> inputs = tokenizer(["This is the beginning of a long story"], return_tensors="pt")

>>> # watermarked outputs
>>> out = model.generate(inputs["input_ids"], watermark=True, tokenizer=tokenizer, max_length=20, do_sample=False)
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"This is the beginning of a long story, but I'll try to keep it short."

>>> # normal generation
>>> out = model.generate(inputs["input_ids"], watermark=False, max_length=20, do_sample=False)
>>> tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"This is the beginning of a long story.\n\nOnce upon a time, there was a"
```
"""

def __init__(
self,
vocab_size,
device,
greenlist_ratio: float = 0.25,
bias: float = 2.0,
hashing_key: int = 15485863,
seeding_scheme: str = "lefthash",
):
if seeding_scheme not in ["selfhash", "lefthash"]:
raise ValueError(f"seeding_scheme has to be one of [`selfhash`, `lefthash`], but foind {seeding_scheme}")
if greenlist_ratio >= 1.0 or greenlist_ratio <= 0.0:
raise ValueError(
f"greenlist_ratio has be in range between 0.0 and 1.0, exclusively. but found {greenlist_ratio}"
)

self.vocab_size = vocab_size
self.greenlist_size = int(self.vocab_size * greenlist_ratio)
self.bias = bias
self.seeding_scheme = seeding_scheme
self.rng = torch.Generator(device=device)
self.hash_key = hashing_key

self.rng.manual_seed(hashing_key)
self.table_size = 1_000_003
self.fixed_table = torch.randperm(self.table_size, generator=self.rng, device=device)

def set_seed(self, input_ids: torch.LongTensor):
seed = self.hash_key * input_ids[-1].item()
self.rng.manual_seed(seed % (2**64 - 1))

def _get_greenlist_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
self.set_seed(input_ids)
vocab_permutation = torch.randperm(self.vocab_size, device=input_ids.device, generator=self.rng)
greenlist_ids = vocab_permutation[: self.greenlist_size]
return greenlist_ids

def _score_rejection_sampling(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.LongTensor:
"""
Generate greenlist based on current candidate next token. Reject and move on if necessary.
Runs for a fixed number of steps only for efficiency, since the method is not batched.
"""
final_greenlist = []
_, greedy_predictions = scores.sort(dim=-1, descending=True)
for i in range(40):
greenlist_ids = self._get_greenlist_ids(torch.cat([input_ids, greedy_predictions[i, None]], dim=-1))
if greedy_predictions[i] in greenlist_ids:
final_greenlist.append(greedy_predictions[i])
return torch.tensor(final_greenlist, device=input_ids.device)

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
greenlist_token_ids = torch.empty(
scores.shape[0], self.greenlist_size, device=scores.device, dtype=torch.int64
)
for b_idx, input_seq in enumerate(input_ids):
if self.seeding_scheme == "selfhash":
greenlist_ids = self._score_rejection_sampling(input_seq, scores[b_idx])
else:
greenlist_ids = self._get_greenlist_ids(input_ids=input_seq)

# Greenlists can differ in length with selfhash, so we pad shorter ones by duplicating the last token
if greenlist_ids.shape[-1] < greenlist_token_ids.shape[-1]:
max_diff = greenlist_token_ids.shape[-1] - greenlist_ids.shape[-1]
greenlist_ids = F.pad(greenlist_ids, (0, max_diff), value=greenlist_ids[-1])
greenlist_token_ids[b_idx] = greenlist_ids

green_tokens_mask = torch.full_like(scores, False, dtype=torch.bool)
batch_indices = torch.arange(scores.shape[0]).unsqueeze(1)
green_tokens_mask[batch_indices, greenlist_token_ids] = True
scores[green_tokens_mask] = scores[green_tokens_mask] + self.bias
return scores
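
For readers skimming the diff, the "lefthash" path above boils down to: seed an RNG from `hashing_key * last_token`, permute the vocabulary, keep the first `greenlist_ratio * vocab_size` ids, and add `bias` to those logits. A standalone, non-batched sketch, illustrative only and distilled from `set_seed`, `_get_greenlist_ids`, and `__call__`:

```python
# Minimal, non-batched sketch of the "lefthash" greenlist selection implemented above.
import torch

def lefthash_greenlist(last_token_id: int, vocab_size: int,
                       greenlist_ratio: float = 0.25,
                       hashing_key: int = 15485863) -> torch.LongTensor:
    # Seed an RNG from the hashing key and the last generated token, then keep
    # the first `greenlist_ratio * vocab_size` ids of a vocabulary permutation.
    rng = torch.Generator()
    rng.manual_seed((hashing_key * last_token_id) % (2**64 - 1))
    permutation = torch.randperm(vocab_size, generator=rng)
    return permutation[: int(vocab_size * greenlist_ratio)]

# Bias the green tokens' scores for one sequence, as __call__ does per batch row.
vocab_size, bias = 20, 2.0
scores = torch.zeros(vocab_size)
green_ids = lefthash_greenlist(last_token_id=10, vocab_size=vocab_size)
scores[green_ids] += bias
```
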
23 changes: 23 additions & 0 deletions src/transformers/generation/utils.py
@@ -73,6 +73,7 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from .stopping_criteria import (
MaxLengthCriteria,
@@ -85,6 +86,7 @@

if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..tokenization_utils_base import PreTrainedTokenizerBase
from .streamers import BaseStreamer

logger = logging.get_logger(__name__)
@@ -755,6 +757,8 @@ def _get_logits_processor(
encoder_input_ids: torch.LongTensor,
prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]],
logits_processor: Optional[LogitsProcessorList],
device: str = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
@@ -871,6 +875,22 @@ def _get_logits_processor(
FutureWarning,
)
processors.append(ForceTokensLogitsProcessor(generation_config.forced_decoder_ids, _has_warned=True))
if generation_config.watermark:
if tokenizer is None:
raise ValueError(
"Generation config's 'watermark' is set to `True` but tokenizer not found. "
"Pass the model's tokenizer as input to `generate`."
)
processors.append(
WatermarkLogitsProcessor(
vocab_size=tokenizer.vocab_size,
device=device,
greenlist_ratio=generation_config.greenlist_ratio,
bias=generation_config.watermark_bias,
hashing_key=generation_config.hashing_key,
seeding_scheme=generation_config.seeding_scheme,
)
)
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
@@ -1321,6 +1341,7 @@ def generate(
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None)
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())

@@ -1474,6 +1495,8 @@
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
tokenizer=tokenizer,
device=inputs_tensor.device,
Review comment from zucchini-nlp (Member, Author) on `device=inputs_tensor.device`:
I hope that passing a device directly into the processor works for multi-GPU generate.

Actually, it is quite handy to be able to init tensors on their devices while initializing the processor, especially as we make processors compile-compatible, where we have already moved to initializing some arguments in tensor format.

model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
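
The `WatermarkDetector` named in the PR title arrives in later commits and is not part of this two-commit view. As a rough illustration of the detection principle from the paper only (not the API added by this PR): re-derive each position's greenlist, count how many generated tokens landed in it, and compare that count against the `greenlist_ratio` expected by chance via a z-score. A hedged sketch, reusing the hypothetical `lefthash_greenlist` helper from the earlier sketch:

```python
# Detection principle only; not the WatermarkDetector class added later in this PR.
import math
import torch

def green_token_z_score(sequence: torch.LongTensor, vocab_size: int,
                        greenlist_ratio: float = 0.25) -> float:
    """Large positive z-scores suggest the text was generated with the watermark."""
    hits, scored = 0, 0
    for prev, cur in zip(sequence[:-1], sequence[1:]):
        green_ids = lefthash_greenlist(prev.item(), vocab_size, greenlist_ratio)
        hits += int((green_ids == cur).any())
        scored += 1
    expected = greenlist_ratio * scored
    variance = scored * greenlist_ratio * (1.0 - greenlist_ratio)
    return (hits - expected) / math.sqrt(variance)
```
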
24 changes: 24 additions & 0 deletions tests/generation/test_logits_process.py
@@ -52,6 +52,7 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.generation.logits_process import BarkEosPrioritizerLogitsProcessor

@@ -840,3 +841,26 @@ def test_early_stop_processor_multi_eos(self):
[float("-inf"), float("-inf"), scores[0][0], scores[0][0]],
]
self.assertListEqual(actual_scores.tolist(), expected_scores_list)

def test_watermarking_processor(self):
batch_size = 3
vocab_size = 20

input_ids = ids_tensor((batch_size, 5), vocab_size=20)
scores = self._get_uniform_logits(batch_size, vocab_size)

# raise error if incorrect seeding_scheme is passed
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", seeding_scheme="hash")

# raise error if the greenlist_ratio is not in range (0.0, 1.0)
with self.assertRaises(ValueError):
WatermarkLogitsProcessor(vocab_size=vocab_size, device="cpu", greenlist_ratio=1.2)

watermark = WatermarkLogitsProcessor(vocab_size=vocab_size, device=input_ids.device)

# use a fixed id for the last token, needed for reproducibility in tests
input_ids[:, -1] = 10
scores_wo_bias = scores[:, -1].clone()
out = watermark(input_ids=input_ids, scores=scores)
self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
21 changes: 21 additions & 0 deletions tests/generation/test_utils.py
@@ -2805,6 +2805,27 @@ def test_beam_search_low_memory(self):
)
self.assertListEqual(low_output.tolist(), high_output.tolist())

@slow
def test_watermark_generation(self):
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
tokenizer.pad_token_id = tokenizer.eos_token_id
model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)

output = model.generate(**model_inputs, watermark=True, tokenizer=tokenizer, do_sample=False, max_length=10)
output_selfhash = model.generate(
**model_inputs,
watermark=True,
tokenizer=tokenizer,
seeding_scheme="selfhash",
do_sample=False,
max_length=10,
)

# as long as we use the same inputs, hashing key and device, outputs are deterministic with greedy decoding
self.assertListEqual(output.tolist(), [[40, 481, 307, 736, 2582, 553, 262, 1893, 531, 13]])
self.assertListEqual(output_selfhash.tolist(), [[40, 481, 307, 262, 717, 530, 284, 9159, 326, 262]])

@slow
def test_beam_search_example_integration(self):
# PT-only test: TF doesn't have a BeamSearchScorer