Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Watermarking LogitsProcessor and WatermarkDetector #29676

Merged
merged 47 commits into from
May 14, 2024

Conversation

zucchini-nlp
Copy link
Member

What does this PR do?

Adds a watermarking technique proposed in this paper to transformers logits processor. I added only the simple method (algorithm 2 from paper) and the robust one (algorithm 3), both with context length of 1 token only. I am not sure if we should support higher context width, defined by user in generation config.

In contrast to the original repo, masking now is done in batched manner. Yet, I could not make the _get_greenlist_ids batched, so we are still left with a loop over batches one by one...

Note, this is only a processor for generation with watermarking. Anyone who uses it and wants later to detect the watermarked text, has to use the Detector from the original repo, using their own private hashing keys.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@gante

@@ -1474,6 +1495,8 @@ def generate(
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
tokenizer=tokenizer,
device=inputs_tensor.device,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope passing a device into the processor directly, works for the multi-gpu generate.

Actually, this is quite handy to be able to init tensors in their devices, while init the processor. Especially when we make compile compatible processors, where we already moved to init some of arguments in tensor format.

@zucchini-nlp
Copy link
Member Author

cc @JonasGeiping @jwkirchenbauer

I would love to have your feedback on the default values we chose to use 🤗

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@JonasGeiping
Copy link

Hi, great work! Supporting a default context length of 4 would be something I would really advocate for, based on results such as Fig.2 in https://arxiv.org/abs/2306.04634.

A separate concern is execution speed. We hooked the watermark into cuda.rng.manual_seed, which is convenient, but not actually an efficient approach, and cannot be batched. A future-proof way of doing this would probably circumvent CUDA alltogether, and include in a different implementation of [list_of_integers + salt] -> hash -> pseudorandom green/red partition table, but we also didn't do that and I am not sure what your time frame for this feature is.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR won't be complete without the Detector -- without it, our users will struggle to use the watermarking techniques and we will struggle to test their correctness.

I'd suggest adding the Detector in a new file in the generation folder (watermarking.py). In their docstrings we could add usage examples :)

src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/logits_process.py Outdated Show resolved Hide resolved
src/transformers/generation/logits_process.py Show resolved Hide resolved
src/transformers/generation/logits_process.py Outdated Show resolved Hide resolved
src/transformers/generation/logits_process.py Outdated Show resolved Hide resolved
src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/logits_process.py Show resolved Hide resolved
src/transformers/generation/logits_process.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
zucchini-nlp and others added 4 commits March 15, 2024 21:12
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@zucchini-nlp
Copy link
Member Author

@JonasGeiping Thanks for the feedback! Yes, indeed the higher context width has better performance. I had reluctance adding more complexity to the code when opening PR, but now that we are okay with adding the whole watermarking functionality, I will add possibility for users to set their own context width.

Yes, different implementation of rng which works in batched form would be very nice to have. Right now I am not planning to work on it, and I prefer to leave it for future plans if we see active usage of the watermarking feature 😄

@JonasGeiping
Copy link

complexity could be reduced a bit by removing self-hashing for now. This setting has several implementation complexities, and without efficient RNG is quite slow to use it during text-gen for a purpose that is not testing watermark quality.

@zucchini-nlp
Copy link
Member Author

@gante , where can we add a doc for the detector? The tests are failing otherwise.

@gante
Copy link
Member

gante commented Mar 20, 2024

@zucchini-nlp here -- https://github.com/huggingface/transformers/blob/main/docs/source/en/internal/generation_utils.md

(ping me again when it's ready for a review :) )

@zucchini-nlp
Copy link
Member Author

@gante sorry, forgot to tag. Yes, ready to review. Added a config for watermark args, changed cache size and rewrote some docs.

Copy link
Member

@gante gante left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The missing bits are nits, so I'm approving the PR. Well done, this is a cool feature 🔥

Missing: somewhere in the docs (in addition to the API docs) showcasing this feature. Perhaps a section below this one?

src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
@@ -17,7 +17,7 @@

import numpy as np

from ... import is_vision_available, requires_backends
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect these changes come from another PR :p

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they do, i need to rebase main after it's merged

src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/watermarking.py Outdated Show resolved Hide resolved
src/transformers/generation/watermarking.py Outdated Show resolved Hide resolved

"""

# Let's assume that if one batch start with `bos`, all batched also do
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: what happens with left-padding? Does left-padding have an impact on the detector?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I did not think someone would feed left-padded text. What I usually do if feeding only the generated text part, excluding the prompt.
Prompt itself has some effect on the z-score but in my toy examples the final prediction did not change. But I believe long prompt with smaller generated text might be predicted as "human" 🤔

Anyway, I will better explain this in the docs, adding more information next to the "Generation strategies".

src/transformers/generation/watermarking.py Outdated Show resolved Hide resolved
@gante gante requested a review from ArthurZucker April 3, 2024 15:52
@gante
Copy link
Member

gante commented Apr 3, 2024

@zucchini-nlp I'd also edit the PR header, it is outdated :) (for instance, it says users should use the detector from the original repo)

@zucchini-nlp zucchini-nlp changed the title Add Watermarking LogitsProcessor Add Watermarking LogitsProcessor and WatermarkDetector Apr 3, 2024
zucchini-nlp and others added 4 commits April 3, 2024 21:10
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@ArthurZucker
Copy link
Collaborator

I'll review next week! 🤗

@ydshieh
Copy link
Collaborator

ydshieh commented Apr 5, 2024

Hi, @zucchini-nlp

Running

RUN_SLOW=1 TF_FORCE_GPU_ALLOW_GROWTH=true python3 -m pytest -v tests/generation/test_utils.py::GenerationIntegrationTests::test_watermark_generation

gives

ValueError: The following `model_kwargs` are not used by the model: ['watermarking_args'] (note: typos in the generate arguments will also show up in this list)

Could you look into this?

Full error log

self = GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (dro...((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
model_kwargs = {'attention_mask': tensor([[1, 1, 1]], device='cuda:0'), 'input_ids': tensor([[ 40, 481, 307]], device='cuda:0'), 'watermarking_args': {'bias': 2.0, 'context_width': 1, 'greenlist_ratio': 0.25, 'hashing_key': 15485863, ...}}

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        """Validates model kwargs for generation. Generate argument typos will also be caught here."""
        # If a `Cache` instance is passed, checks whether the model is compatible with it
        if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
            raise ValueError(
                f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
                "check the model documentation for supported cache formats."
            )
    
        # Excludes arguments that are handled before calling any model function
        if self.config.is_encoder_decoder:
            for key in ["decoder_input_ids"]:
                model_kwargs.pop(key, None)
    
        unused_model_args = []
        model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
        # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
        # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
        if "kwargs" in model_args or "model_kwargs" in model_args:
            model_args |= set(inspect.signature(self.forward).parameters)
    
        # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
        if self.config.is_encoder_decoder:
            base_model = getattr(self, self.base_model_prefix, None)
    
            # allow encoder kwargs
            encoder = getattr(self, "encoder", None)
            # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
            # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
            # TODO: A better way to handle this.
            if encoder is None and base_model is not None:
                encoder = getattr(base_model, "encoder", None)
    
            if encoder is not None:
                encoder_model_args = set(inspect.signature(encoder.forward).parameters)
                model_args |= encoder_model_args
    
            # allow decoder kwargs
            decoder = getattr(self, "decoder", None)
            if decoder is None and base_model is not None:
                decoder = getattr(base_model, "decoder", None)
    
            if decoder is not None:
                decoder_model_args = set(inspect.signature(decoder.forward).parameters)
                model_args |= {f"decoder_{x}" for x in decoder_model_args}
    
            # allow assistant_encoder_outputs to be passed if we're doing assisted generating
            if "assistant_encoder_outputs" in model_kwargs:
                model_args |= {"assistant_encoder_outputs"}
    
        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)
    
        if unused_model_args:
>           raise ValueError(
                f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
                " generate arguments will also show up in this list)"
            )
E           ValueError: The following `model_kwargs` are not used by the model: ['watermarking_args'] (note: typos in the generate arguments will also show up in this list)

src/transformers/generation/utils.py:1136: ValueError

@zucchini-nlp
Copy link
Member Author

@ydshieh my bad, did not fix tests after latest changes

@ydshieh
Copy link
Collaborator

ydshieh commented Apr 5, 2024

No worry. (But it is always nice to check the tests once we think a PR is ready at some point 😄 )
Thanks for the fixing!

@ydshieh
Copy link
Collaborator

ydshieh commented Apr 5, 2024

The following objects docstrings do not match their signature. Run make fix-copies to fix this. In some cases, this error may be raised incorrectly by the docstring checker. If you think this is the case, you can manually check the docstrings

  • WatermarkDetector

We need to check the docstrings for WatermarkDetector. You can first run make fix-copies but don't apply the changes blindly - check them and make a decision 🙏

@zucchini-nlp
Copy link
Member Author

@ArthurZucker ping

@ArthurZucker
Copy link
Collaborator

Thanks for the ping, reviewing now!

@ArthurZucker
Copy link
Collaborator

OUps on it today sorry @zucchini-nlp

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for working on this interesting feature!
I think the doc can be improved a bit, but otherwise very clean

docs/source/en/generation_strategies.md Outdated Show resolved Hide resolved
Comment on lines +199 to +206
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
>>> tok.pad_token_id = tok.eos_token_id
>>> tok.padding_side = "left"

>>> inputs = tok(["This is the beginning of a long story", "Alice and Bob are"], padding=True, return_tensors="pt")
>>> input_len = inputs["input_ids"].shape[-1]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from this snippet I have no idea what the green and red is, no idea what the prediction says, Truem True?
Is the detector detecting watermarking? What is it detecting etc. Think this needs to be improved!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add a bit more info, but for full understanding it is better to read the paper. I will give a very brief overview of the general idea behind the tecnique :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Summing up without the user or me to dig into the paper is nice

src/transformers/generation/logits_process.py Show resolved Hide resolved
Comment on lines 2276 to 2283
```
>>> # to detect watermarked text use the WatermarkDetector class
>>> from transformers import WatermarkDetector
>>> detector = WatermarkDetector(model_config=model.config, device="cpu", watermarking_config= watermarking_config)
>>> detection_preds = detector(out)
>>> detection_preds
array([ True])
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring seems a bit wrong

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(the code blkock stops while it should not

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example is a bit better! But would be nice to explain in the md that a bias is added to the logits of both, which makes the next token generated be "stilleinstead ofin` (I am saying this as an example, but what actually happend?). which words are green in this example?)

Also could we manually set green words?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is a randomized internal process of "green" token selection, which is set by indicating a hashing fn and a hash key. This two can be later used to reverse the process and check how many green tokens the generated text contains, and if it's statistically likely for a human-generated text to have this proportion of green tokens.

Not sure if we have to give the whole explanation of the algorithm or just refer to the paper though 🤔

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's give a brief explanation, it's better to sum up and also link to the paper, but for both me (the reviewer) and any curious user, we don't want to go and read everything!

)
return num_tokens_scored_batch, green_token_count_batch

def _compute_z_score(self, green_token_count: np.array, total_num_tokens: np.array) -> np.array:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is a z score? Where does it come from?

Comment on lines +113 to +117
>>> detection_out_watermarked.prediction
array([ True, True])

>>> detection_out.prediction
array([False, False])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's always going to be watermarked in batches for the generate right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be unbatched also, depends on what is passed as the input so that it just works as a simple add-on on the generate. The logits_process.py has a one element watermarking in its docstring

tests/generation/test_logits_process.py Show resolved Hide resolved
Comment on lines 49 to 50
z_score (np.array of shape (batch_size)):
Array containing the z-score for each batch.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a more reprensentative name for that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a term from stats, so I guess we cannot just call it how we want. I added a small explanation of what is z-score

src/transformers/generation/watermarking.py Outdated Show resolved Hide resolved
@zucchini-nlp
Copy link
Member Author

The PR seems ready for me, all comments are addressed and the tests are passing. I see that I got two approvals, but I will leave it here until next week May 13 in case anyone wants to add something 😄

@zucchini-nlp zucchini-nlp merged commit 5ad960f into huggingface:main May 14, 2024
22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants