Add Watermarking LogitsProcessor and WatermarkDetector #29676

Merged · 47 commits · May 14, 2024

Changes from 41 commits

Commits:
9af1d2c  add watermarking processor (zucchini-nlp, Mar 15, 2024)
92a5214  remove the other hashing (context width=1 always) (zucchini-nlp, Mar 15, 2024)
3b2c6da  make style (zucchini-nlp, Mar 15, 2024)
bd1a8aa  Update src/transformers/generation/logits_process.py (zucchini-nlp, Mar 15, 2024)
3756540  Update src/transformers/generation/logits_process.py (zucchini-nlp, Mar 15, 2024)
c67069b  Update src/transformers/generation/logits_process.py (zucchini-nlp, Mar 15, 2024)
52b58cf  Update src/transformers/generation/configuration_utils.py (zucchini-nlp, Mar 15, 2024)
e4c92b8  update watermarking process (zucchini-nlp, Mar 18, 2024)
e77ea5e  add detector (zucchini-nlp, Mar 18, 2024)
ab8f79f  update tests to use detector (zucchini-nlp, Mar 18, 2024)
bd4b875  fix failing tests (zucchini-nlp, Mar 19, 2024)
9bf52b3  Merge remote-tracking branch 'upstream/main' into watermark (zucchini-nlp, Mar 19, 2024)
e2e689b  rename `input_seq` (zucchini-nlp, Mar 19, 2024)
6dd8eb3  make style (zucchini-nlp, Mar 19, 2024)
5ba45c0  doc for processor (zucchini-nlp, Mar 19, 2024)
f9c6594  minor fixes (zucchini-nlp, Mar 19, 2024)
0597a17  docs (zucchini-nlp, Mar 21, 2024)
d4f5de1  Merge remote-tracking branch 'upstream/main' into watermark (zucchini-nlp, Mar 21, 2024)
77d8745  make quality (zucchini-nlp, Mar 21, 2024)
8cc4453  Merge remote-tracking branch 'upstream/main' into watermark (zucchini-nlp, Mar 22, 2024)
82f0853  Update src/transformers/generation/configuration_utils.py (zucchini-nlp, Mar 26, 2024)
1216142  Update src/transformers/generation/logits_process.py (zucchini-nlp, Mar 26, 2024)
5e671e0  Update src/transformers/generation/watermarking.py (zucchini-nlp, Mar 26, 2024)
f50e945  Update src/transformers/generation/watermarking.py (zucchini-nlp, Mar 26, 2024)
c1c9ed8  Update src/transformers/generation/watermarking.py (zucchini-nlp, Mar 26, 2024)
2055f56  Merge remote-tracking branch 'upstream/main' into watermark (zucchini-nlp, Mar 26, 2024)
3578150  add PR suggestions (zucchini-nlp, Mar 27, 2024)
b477eb5  let's use lru_cache's default max size (128) (zucchini-nlp, Mar 27, 2024)
cab4969  import processor if torch available (zucchini-nlp, Mar 27, 2024)
c03e752  maybe like this (zucchini-nlp, Mar 27, 2024)
b28f646  lets move the config to torch independet file (zucchini-nlp, Mar 27, 2024)
966808d  add docs (zucchini-nlp, Mar 27, 2024)
2d0c3e3  tiny docs fix to make the test happy (zucchini-nlp, Apr 1, 2024)
8223376  Update src/transformers/generation/configuration_utils.py (zucchini-nlp, Apr 3, 2024)
f33a3a2  Update src/transformers/generation/watermarking.py (zucchini-nlp, Apr 3, 2024)
6e60d32  PR suggestions (zucchini-nlp, Apr 3, 2024)
7ae9ae9  add docs (zucchini-nlp, Apr 3, 2024)
863663c  fix test (zucchini-nlp, Apr 5, 2024)
76a66b5  Merge remote-tracking branch 'upstream/main' into watermark (zucchini-nlp, Apr 5, 2024)
177c765  fix docs (zucchini-nlp, Apr 5, 2024)
e6da307  Merge branch 'huggingface:main' into watermark (zucchini-nlp, Apr 18, 2024)
f036d49  address pr comments (zucchini-nlp, May 8, 2024)
1c3e987  Merge branch 'main' into watermark (zucchini-nlp, May 9, 2024)
7f33cc3  style (zucchini-nlp, May 9, 2024)
5e70bab  Revert "style" (zucchini-nlp, May 9, 2024)
3be20e3  correct style (zucchini-nlp, May 9, 2024)
f15935c  make doctest green (zucchini-nlp, May 9, 2024)
41 changes: 41 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -173,6 +173,47 @@ your screen, one word at a time:
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
```


## Watermarking

The `generate()` method supports watermarking the generated text by randomly partitioning the tokens into "red" and "green" sets and adding a small bias to the "green" tokens.
This watermarking strategy was proposed in the paper ["On the Reliability of Watermarks for Large Language Models"](https://arxiv.org/abs/2306.04634).
It can be used with any generative model in `transformers` and does not require a separate classification model to detect the watermark.

To trigger watermarking, pass a `WatermarkingConfig` with the desired arguments, or initialize a `WatermarkingConfig`
without arguments to use the default values. The watermarked text can then be detected with a `WatermarkDetector`.


<Tip warning={true}>

The `WatermarkDetector` internally relies on the proportion of "green" tokens in the analyzed text, and on whether that text follows the expected coloring pattern.
That is why it is recommended to strip off the prompt text if it is much longer than the generated text.
Padding can have a similar effect when one sequence in the batch is much longer than the others, causing the remaining rows to be padded.

</Tip>


```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, WatermarkDetector, WatermarkingConfig

>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> tok.pad_token_id = tok.eos_token_id
>>> tok.padding_side = "left"

>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
>>> input_len = inputs["input_ids"].shape[-1]

Comment on lines +207 to +214

Collaborator:

From this snippet I have no idea what the green and red are, no idea what the prediction says (True, True?).
Is the detector detecting watermarking? What is it detecting, etc.? I think this needs to be improved!

Member Author:

I will add a bit more info, but for full understanding it is better to read the paper. I will give a very brief overview of the general idea behind the technique :)

Collaborator:

Thanks! Summing it up so that neither the user nor I have to dig into the paper is nice.

>>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
>>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=False, max_length=20)

>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config=watermarking_config)
>>> detection_out = detector(out, return_dict=True)
>>> detection_out.prediction
array([ True,  True])
```
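As a sanity check, the detector is expected to return `False` for text generated without the watermark. The snippet below is an illustrative sketch added for this write-up, not part of the diff: it reuses `model`, `inputs`, `input_len`, and `detector` from the example above, and it strips the prompt as the tip recommends; the exact output depends on the model.

```python
>>> # Generate WITHOUT a watermarking config; the detector should not flag this text.
>>> plain_out = model.generate(**inputs, do_sample=False, max_length=20)

>>> # Score only the generated tokens by stripping the (padded) prompt.
>>> detector(plain_out[:, input_len:], return_dict=True).prediction
array([False, False])
```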


## Decoding strategies

Certain combinations of the `generate()` parameters, and ultimately `generation_config`, can be used to enable specific
10 changes: 10 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -206,6 +206,10 @@ generation.
[[autodoc]] WhisperTimeStampLogitsProcessor
- __call__

[[autodoc]] WatermarkLogitsProcessor
- __call__


### TensorFlow

[[autodoc]] TFForcedBOSTokenLogitsProcessor
@@ -362,3 +366,9 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length


## Watermark Utils

[[autodoc]] WatermarkDetector
- __call__
2 changes: 2 additions & 0 deletions docs/source/en/main_classes/text_generation.md
@@ -41,6 +41,8 @@ like token streaming.
- validate
- get_generation_mode

[[autodoc]] generation.WatermarkingConfig

## GenerationMixin

[[autodoc]] generation.GenerationMixin
13 changes: 11 additions & 2 deletions src/transformers/__init__.py
@@ -104,7 +104,12 @@
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": ["GenerationConfig", "TextIteratorStreamer", "TextStreamer"],
"generation": [
"GenerationConfig",
"TextIteratorStreamer",
"TextStreamer",
"WatermarkingConfig",
],
"hf_argparser": ["HfArgumentParser"],
"hyperparameter_search": [],
"image_transforms": [],
@@ -1443,6 +1448,8 @@
"TopPLogitsWarper",
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WatermarkDetector",
"WatermarkLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
]
)
@@ -5043,7 +5050,7 @@
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin

# Generation
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer
from .generation import GenerationConfig, TextIteratorStreamer, TextStreamer, WatermarkingConfig
from .hf_argparser import HfArgumentParser

# Integrations
@@ -6361,6 +6368,8 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkDetector,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .modeling_utils import PreTrainedModel
14 changes: 12 additions & 2 deletions src/transformers/generation/__init__.py
@@ -18,7 +18,7 @@


_import_structure = {
"configuration_utils": ["GenerationConfig", "GenerationMode"],
"configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
"streamers": ["TextIteratorStreamer", "TextStreamer"],
}

@@ -77,6 +77,7 @@
"TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
"WatermarkLogitsProcessor",
]
_import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria",
@@ -104,6 +105,10 @@
"GenerateDecoderOnlyOutput",
"GenerateEncoderDecoderOutput",
]
_import_structure["watermarking"] = [
"WatermarkDetector",
"WatermarkDetectorOutput",
]

try:
if not is_tf_available():
@@ -172,7 +177,7 @@
]

if TYPE_CHECKING:
from .configuration_utils import GenerationConfig, GenerationMode
from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
from .streamers import TextIteratorStreamer, TextStreamer

try:
@@ -215,6 +220,7 @@
TopPLogitsWarper,
TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .stopping_criteria import (
@@ -243,6 +249,10 @@
SampleDecoderOnlyOutput,
SampleEncoderDecoderOutput,
)
from .watermarking import (
WatermarkDetector,
WatermarkDetectorOutput,
)

try:
if not is_tf_available():
180 changes: 179 additions & 1 deletion src/transformers/generation/configuration_utils.py
@@ -18,6 +18,7 @@
import json
import os
import warnings
from dataclasses import dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from .. import __version__
@@ -224,7 +225,23 @@ class GenerationConfig(PushToHubMixin):
low_memory (`bool`, *optional*):
Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
Used with beam search and contrastive search.

watermarking_config (Union[`WatermarkingConfig`, `dict`], *optional*):
Arguments used to watermark the model outputs by adding a small bias to a randomly selected set of "green" tokens.
If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
- greenlist_ratio (`float`):
Used for watermarking. The ratio of "green" tokens to the vocabulary size. Defaults to 0.25.
- bias (`float`):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
- hashing_key (`int`):
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
- seeding_scheme (`str`):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
- context_width(`int`):
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
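Since a plain `dict` is converted to a `WatermarkingConfig` internally, both call styles below enable the same watermark. This is an illustrative sketch, not part of the diff, assuming a `model` and tokenized `inputs` as in the documentation example earlier in this PR:

```python
from transformers import WatermarkingConfig

# Explicit config object:
out = model.generate(**inputs, watermarking_config=WatermarkingConfig(bias=2.5, seeding_scheme="selfhash"))

# Equivalent plain dict, converted internally by GenerationConfig:
out = model.generate(**inputs, watermarking_config={"bias": 2.5, "seeding_scheme": "selfhash"})
```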

> Parameters that define the output variables of `generate`

@@ -340,6 +357,13 @@ def __init__(self, **kwargs):
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.guidance_scale = kwargs.pop("guidance_scale", None)
self.low_memory = kwargs.pop("low_memory", None)
watermarking_config = kwargs.pop("watermarking_config", None)
if watermarking_config is None:
self.watermarking_config = None
elif isinstance(watermarking_config, WatermarkingConfig):
self.watermarking_config = watermarking_config
else:
self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)

# Parameters that define the output variables of `generate`
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
@@ -610,6 +634,12 @@ def validate(self, is_init=False):
f"({self.num_beams})."
)

# check watermarking arguments
if self.watermarking_config is not None:
if not isinstance(self.watermarking_config, WatermarkingConfig):
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()

# 5. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
@@ -1018,7 +1048,16 @@ def convert_keys_to_string(obj):
else:
return obj

def convert_dataclass_to_dict(obj):
if isinstance(obj, dict):
return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
elif is_dataclass(obj):
return obj.to_dict()
else:
return obj

config_dict = convert_keys_to_string(config_dict)
config_dict = convert_dataclass_to_dict(config_dict)

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

@@ -1090,3 +1129,142 @@ def update(self, **kwargs):
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs


@dataclass
class WatermarkingConfig:
"""
Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.

Accepts the following keys:
- greenlist_ratio (`float`):
Used for watermarking. The ratio of "green" tokens to the vocabulary size. Defaults to 0.25.
- bias (`float`):
Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
- hashing_key (`int`):
Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
- seeding_scheme (`str`):
Algorithm to use for watermarking. Accepts values:
- "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
- "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
- context_width(`int`):
The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
"""

def __init__(
self,
greenlist_ratio: Optional[float] = 0.25,
bias: Optional[float] = 2.0,
hashing_key: Optional[int] = 15485863,
seeding_scheme: Optional[str] = "lefthash",
context_width: Optional[int] = 1,
):
self.greenlist_ratio = greenlist_ratio
self.bias = bias
self.hashing_key = hashing_key
self.seeding_scheme = seeding_scheme
self.context_width = context_width

@classmethod
def from_dict(cls, config_dict, **kwargs):
"""
Constructs a WatermarkingConfig instance from a dictionary of parameters.

Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.

Returns:
WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
return config
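For illustration (a sketch, not part of the diff): dictionary values populate the config, and extra keyword arguments override matching keys.

```python
from transformers import WatermarkingConfig

config = WatermarkingConfig.from_dict({"bias": 2.5, "context_width": 2}, bias=4.0)
assert config.bias == 4.0         # kwarg overrides the dict value
assert config.context_width == 2  # dict value is kept
```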

def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.

Args:
json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

writer.write(json_string)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.

Returns:
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
return output

def __iter__(self):
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value

def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.

Returns:
str: JSON formatted string representing the configuration instance.
"""
return json.dumps(self.__dict__, indent=2) + "\n"

def update(self, **kwargs):
"""
Update the configuration attributes with new values.

Args:
**kwargs: Keyword arguments representing configuration attributes and their new values.
"""
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
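A brief sketch of the serialization helpers and `update` defined above (illustrative, not part of the diff; `not_an_attribute` is a made-up key used to show that unknown keys are ignored):

```python
import json

from transformers import WatermarkingConfig

config = WatermarkingConfig()
config.update(bias=3.0, not_an_attribute=1)  # unknown keys are silently ignored
assert config.bias == 3.0

# to_json_string / from_dict round-trip through plain JSON.
restored = WatermarkingConfig.from_dict(json.loads(config.to_json_string()))
assert dict(restored) == config.to_dict()
```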

def validate(self):
watermark_missing_arg_msg = (
"Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
if self.seeding_scheme not in ["selfhash", "lefthash"]:
raise ValueError(
watermark_missing_arg_msg.format(
key="seeding_scheme",
correct_value="[`selfhash`, `lefthash`]",
found_value=self.seeding_scheme,
),
)
if not 0.0 <= self.greenlist_ratio <= 1.0:
raise ValueError(
watermark_missing_arg_msg.format(
key="greenlist_ratio",
correct_value="in range between 0.0 and 1.0",
found_value=self.greenlist_ratio,
),
)
if not self.context_width >= 1:
raise ValueError(
watermark_missing_arg_msg.format(
key="context_width",
correct_value="a positive integer",
found_value=self.context_width,
),
)
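Finally, an illustrative sketch of `validate` (not part of the diff; `"righthash"` is a made-up invalid scheme): the constructor accepts any value, and `validate()` raises a `ValueError` built from the message template above.

```python
from transformers import WatermarkingConfig

bad = WatermarkingConfig(seeding_scheme="righthash")  # hypothetical invalid scheme
try:
    bad.validate()
except ValueError as err:
    print(err)
    # Some of the keys in `watermarking_config` are defined incorrectly.
    # `seeding_scheme` should be [`selfhash`, `lefthash`] but found righthash
```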