# SynthID Text: Watermarking for Generated Text

This notebook demonstrates how to use the [SynthID Text library][synthid-code]
to apply and detect watermarks on generated text. It is divided into three major
sections and intended to be run end-to-end.

1.  **_Setup_**: Importing the SynthID Text library, choosing your model (either
    [Gemma][gemma] or [GPT-2][gpt2]) and device (either CPU or GPU, depending
    on your runtime), defining the watermarking configuration, and initializing
    some helper functions.
1.  **_Applying a watermark_**: Loading your selected model using the
    [Hugging Face Transformers][transformers] library, using that model to
    generate some watermarked text, and comparing the perplexity of the
    watermarked text to that of text generated by the base model.
1.  **_Detecting a watermark_**: Training a detector to recognize text generated
    with a specific watermarking configuration, and then using that detector to
    predict whether a set of examples were generated with that configuration.

As the reference implementation for the
[SynthID Text paper in _Nature_][synthid-paper], this library and notebook are
intended for research review and reproduction only. They should not be used in
production systems. For a production-grade implementation, check out the
official SynthID logits processor in [Hugging Face Transformers][transformers].

[gemma]: https://ai.google.dev/gemma/docs/model_card
[gpt2]: https://huggingface.co/openai-community/gpt2
[synthid-code]: https://github.com/google-deepmind/synthid-text
[synthid-paper]: https://www.nature.com/
[transformers]: https://huggingface.co/docs/transformers/en/index

# 1. Setup

In [None]:
# @title Install and import the required Python packages
#
# @markdown Running this cell may require you to restart your session.

! pip install synthid-text[notebook]

from collections.abc import Sequence
import enum
import gc

import datasets
import huggingface_hub
import numpy as np
from synthid_text import detector_bayesian
from synthid_text import detector_mean
from synthid_text import logits_processing
from synthid_text import synthid_mixin
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
import transformers

In [None]:
# @title Choose your model.
#
# @markdown This reference implementation is configured to use the Gemma v1.0
# @markdown Instruction-Tuned variants in 2B or 7B sizes, or GPT-2.


class ModelName(enum.Enum):
  GPT2 = 'gpt2'
  GEMMA_2B = 'google/gemma-2b-it'
  GEMMA_7B = 'google/gemma-7b-it'


model_name = 'google/gemma-7b-it' # @param ['gpt2', 'google/gemma-2b-it', 'google/gemma-7b-it']
MODEL_NAME = ModelName(model_name)

if MODEL_NAME is not ModelName.GPT2:
  huggingface_hub.notebook_login()

In [None]:
# @title Configure your device
#
# @markdown This notebook loads models from Hugging Face Transformers into the
# @markdown PyTorch deep learning runtime. PyTorch supports generation on CPU or
# @markdown GPU, but your chosen model will run best on the following hardware,
# @markdown some of which may require a
# @markdown [Colab Subscription](https://colab.research.google.com/signup).
# @markdown
# @markdown * Gemma v1.0 2B IT: Use a GPU with at least 16GB of memory, such as a T4.
# @markdown * Gemma v1.0 7B IT: Use a GPU with at least 32GB of memory, such as an A100.
# @markdown * GPT-2: Any runtime will work, though a High-RAM CPU or any GPU
# @markdown   will be faster.

DEVICE = (
    torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
)
DEVICE

In [None]:
# @title Example watermarking config
#
# @markdown SynthID Text produces a distinct watermark for each configuration.
# @markdown The most important part of a configuration is `keys`, a sequence
# @markdown of unique integers.
# @markdown
# @markdown This reference implementation uses a fixed watermarking
# @markdown configuration, which will be displayed when you run this cell.

CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
CONFIG
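
To build intuition for why `keys` matter, the toy sketch below maps a key, the preceding context tokens, and a candidate token to a pseudorandom bit, a simplified stand-in for a g-value. It is a hypothetical illustration built on SHA-256 and is not the hashing scheme SynthID Text uses; `toy_g_value` and the token IDs are made up. The point is only that the same inputs always give the same bit, while changing the key reshuffles the bits.

```python
import hashlib
from collections.abc import Sequence


def toy_g_value(key: int, context: Sequence[int], token: int) -> int:
  """Hypothetical keyed g-function returning a pseudorandom bit in {0, 1}.

  The same (key, context, token) triple always yields the same bit; a
  different key yields an unrelated bit pattern. Not SynthID's real hash.
  """
  payload = ",".join(str(x) for x in (key, *context, token)).encode()
  digest = hashlib.sha256(payload).digest()
  return digest[0] & 1  # Lowest bit of the first digest byte.


context = [101, 2023, 2003, 1037]  # Hypothetical preceding token IDs.
bits_key_a = [toy_g_value(42, context, t) for t in range(1000)]
bits_key_b = [toy_g_value(43, context, t) for t in range(1000)]
print(sum(bits_key_a) / len(bits_key_a))  # Close to 0.5 for a well-mixed hash.
print(sum(a != b for a, b in zip(bits_key_a, bits_key_b)))  # Roughly half.
```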

In [None]:
# @title Initialize the required constants, tokenizer, and logits processor

BATCH_SIZE = 8
NUM_BATCHES = 320
OUTPUTS_LEN = 1024
TEMPERATURE = 0.5
TOP_K = 40
TOP_P = 0.99

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME.value)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

logits_processor = logits_processing.SynthIDLogitsProcessor(
    **CONFIG, top_k=TOP_K, temperature=TEMPERATURE
)

In [None]:
# @title Utility functions to load models, compute perplexity, and process prompts.


def load_model(
    model_name: ModelName,
    expected_device: torch.device,
    enable_watermarking: bool = False,
) -> transformers.PreTrainedModel:
  if model_name == ModelName.GPT2:
    model_cls = (
        synthid_mixin.SynthIDGPT2LMHeadModel
        if enable_watermarking
        else transformers.GPT2LMHeadModel
    )
    model = model_cls.from_pretrained(model_name.value, device_map='auto')
  else:
    model_cls = (
        synthid_mixin.SynthIDGemmaForCausalLM
        if enable_watermarking
        else transformers.GemmaForCausalLM
    )
    model = model_cls.from_pretrained(
        model_name.value,
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )

  if str(model.device) != str(expected_device):
    raise ValueError('Model device not as expected.')

  return model


def _compute_perplexity(
    outputs: torch.LongTensor,
    scores: torch.FloatTensor,
    eos_token_mask: torch.LongTensor,
    watermarked: bool = False,
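) -> torch.Tensor:
  """Returns the batch sum of per-sequence mean negative log-likelihoods.

  For unwatermarked generation, `scores` holds the per-step scores returned by
  `generate`, and the NLL of each sampled token is computed here. For
  watermarked generation, `scores` is used directly as the per-token NLL.
  Tokens after EOS are masked out with `eos_token_mask`. Exponentiate the mean
  over all sequences to obtain perplexity.
  """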
  len_offset = len(scores)
  if watermarked:
    nll_scores = scores
  else:
    nll_scores = [
        torch.gather(
            -torch.log(torch.nn.Softmax(dim=1)(sc)),
            1,
            outputs[:, -len_offset + idx, None],
        )
        for idx, sc in enumerate(scores)
    ]
  nll_sum = torch.nan_to_num(
      torch.squeeze(torch.stack(nll_scores, dim=1), dim=2)
      * eos_token_mask.long(),
      posinf=0,
  )
  nll_sum = nll_sum.sum(dim=1)
  nll_mean = nll_sum / eos_token_mask.sum(dim=1)
  return nll_mean.sum(dim=0)


def _process_raw_prompt(prompt: bytes) -> str:
  """Decodes a raw prompt and, for Gemma models, applies the chat template."""
  if MODEL_NAME == ModelName.GPT2:
    return prompt.decode().strip('"')
  else:
    return tokenizer.apply_chat_template(
        [{'role': 'user', 'content': prompt.decode().strip('"')}],
        tokenize=False,
        add_generation_prompt=True,
    )
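
The helper `_compute_perplexity` averages token NLLs within each sequence (ignoring positions masked out after EOS) and sums those averages over the batch; the later cells divide by the total number of sequences and exponentiate. A pure-Python sketch of that aggregation, using made-up NLL values and a hypothetical `masked_mean_nll` helper:

```python
import math


def masked_mean_nll(
    nlls: list[list[float]], mask: list[list[int]]
) -> list[float]:
  """Mean NLL per sequence, counting only positions where mask is 1."""
  means = []
  for seq_nll, seq_mask in zip(nlls, mask):
    total = sum(n * m for n, m in zip(seq_nll, seq_mask))
    means.append(total / sum(seq_mask))
  return means


# Hypothetical per-token NLLs; the second sequence ends after two tokens, so
# its third value is masked out and never counted.
nlls = [[1.0, 2.0, 3.0], [0.5, 1.5, 9.9]]
mask = [[1, 1, 1], [1, 1, 0]]

per_sequence = masked_mean_nll(nlls, mask)  # [2.0, 1.0]
batch_sum = sum(per_sequence)  # Plays the role of one _compute_perplexity value.
perplexity = math.exp(batch_sum / len(nlls))  # exp(1.5), about 4.48
print(per_sequence, perplexity)
```

Because each sequence's mean is weighted equally, short and long responses contribute equally to the final number; a corpus-level perplexity that averages over all tokens would weight long responses more.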

# 2. Applying a watermark

In [None]:
# @title Generate watermarked output

gc.collect()
torch.cuda.empty_cache()

batch_size = 1
example_inputs = [
    'I enjoy walking with my cute dog',
    'I am from New York',
    'The test was not so very hard after all',
    "I don't think they can score twice in so short a time",
]
example_inputs = example_inputs * (int(batch_size / 4) + 1)
example_inputs = example_inputs[:batch_size]

inputs = tokenizer(
    example_inputs,
    return_tensors='pt',
    padding=True,
).to(DEVICE)

model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)
torch.manual_seed(0)
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    max_length=1024,
    top_k=40,
)

print('Output:\n' + 100 * '-')
for output in outputs:
  print(tokenizer.decode(output, skip_special_tokens=True))
  print(100 * '-')

del inputs, outputs, model
gc.collect()
torch.cuda.empty_cache()

## [Optional] Compare perplexity between watermarked and non-watermarked text

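Generate responses to prompts from the [ELI5 dataset](https://facebookresearch.github.io/ELI5/)
with both the watermarked and non-watermarked models, and verify that the
[perplexity](https://huggingface.co/docs/transformers/en/perplexity) of the
watermarked text is similar to that of the non-watermarked text. For a sequence
of $t$ generated tokens $X = (x_1, \dots, x_t)$, perplexity is the exponentiated
average negative log-likelihood under the model $p_\theta$:

$$\text{PPL}(X) = \exp \left\{ -\frac{1}{t}\sum_{i=1}^{t} \log p_\theta (x_i \mid x_{<i}) \right\}$$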
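
As a worked check of this formula, suppose the model assigns probability $1/V$ to every one of the $t$ tokens, as a uniform guess over a vocabulary of size $V$ would. The perplexity then equals $V$, so lower perplexity means the model finds the text more predictable:

$$\text{PPL}(X) = \exp \left\{ -\frac{1}{t}\sum_{i=1}^{t} \log \frac{1}{V} \right\} = \exp \left\{ \log V \right\} = V$$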

In [None]:
# @title Load the ELI5 dataset with Hugging Face Datasets.

eli5_prompts = datasets.load_dataset("Pavithree/eli5")

In [None]:
# @title Non-watermarked output - perplexity
gc.collect()
torch.cuda.empty_cache()

model = load_model(MODEL_NAME, expected_device=DEVICE)
torch.manual_seed(0)

nonwm_g_values = []
nonwm_eos_masks = []
nonwm_outputs = []
perplexities = []

for batch_id in tqdm.tqdm(range(NUM_BATCHES)):
  prompts = eli5_prompts['train']['title'][
      batch_id * BATCH_SIZE:(batch_id + 1) * BATCH_SIZE]
  prompts = [_process_raw_prompt(prompt.encode()) for prompt in prompts]
  inputs = tokenizer(
      prompts,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
      return_dict_in_generate=True,
      output_scores=True,
  )

  scores = outputs.scores
  outputs = outputs.sequences
  eos_token_mask = logits_processor.compute_eos_token_mask(
      input_ids=outputs[:, inputs_len:],
      eos_token_id=tokenizer.eos_token_id,
  )

  perplexities.append(_compute_perplexity(outputs, scores, eos_token_mask))

  g_values = logits_processor.compute_g_values(
      input_ids=outputs[:, inputs_len:],
  )

  nonwm_g_values.append(g_values.cpu())
  nonwm_eos_masks.append(eos_token_mask.cpu())
  nonwm_outputs.append(outputs.cpu())

  del inputs, prompts, eos_token_mask, g_values, outputs, scores

del model, nonwm_g_values, nonwm_eos_masks, nonwm_outputs
gc.collect()
torch.cuda.empty_cache()

In [None]:
final_perplexity = torch.exp(
    torch.stack(perplexities).sum() / (BATCH_SIZE * NUM_BATCHES)
)
print(f"Perplexity of unwatermarked model: {final_perplexity.item()}")

In [None]:
# @title Watermarked output - perplexity
gc.collect()
torch.cuda.empty_cache()

model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)
torch.manual_seed(0)

wm_outputs = []
wm_g_values = []
wm_eos_masks = []
perplexities = []

for batch_id in tqdm.tqdm(range(NUM_BATCHES)):
  prompts = eli5_prompts['train']['title'][
      batch_id * BATCH_SIZE:(batch_id + 1) * BATCH_SIZE]
  prompts = [_process_raw_prompt(prompt.encode()) for prompt in prompts]
  inputs = tokenizer(
      prompts,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
      return_dict_in_generate=True,
      output_scores=True,
  )
  scores = outputs.scores
  outputs = outputs.sequences

  # Mask to ignore all tokens after the end-of-sequence token.
  eos_token_mask = logits_processor.compute_eos_token_mask(
      input_ids=outputs[:, inputs_len:],
      eos_token_id=tokenizer.eos_token_id,
  )

  perplexities.append(_compute_perplexity(outputs, scores, eos_token_mask, watermarked=True))

  g_values = logits_processor.compute_g_values(
      input_ids=outputs[:, inputs_len:],
  )
  wm_outputs.append(outputs.cpu())
  wm_g_values.append(g_values.cpu())
  wm_eos_masks.append(eos_token_mask.cpu())

  del outputs, scores, inputs, prompts, eos_token_mask, g_values

del model, wm_outputs, wm_g_values, wm_eos_masks
gc.collect()
torch.cuda.empty_cache()

In [None]:
final_perplexity = torch.exp(
    torch.stack(perplexities).sum() / (BATCH_SIZE * NUM_BATCHES)
)
print(f"Perplexity of watermarked model: {final_perplexity.item()}")

# 3. Detecting a watermark

To detect the watermark, you have two options:
1.   Use the simple **Mean** scoring function. This can be done quickly and requires no training.
2.   Use the more powerful **Bayesian** scoring function. This requires training and takes more time.

For a full explanation of these scoring functions, see the paper and its Supplementary Materials.
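
As a rough illustration of the Mean scoring function, assume each response has a `[seq_len, depth]` grid of g-values plus a `[seq_len]` mask, matching the shapes noted in the generation cell below. The minimal sketch here averages g-values over unmasked positions and all depth layers; `toy_mean_score` is hypothetical and is not the `detector_mean` API.

```python
def toy_mean_score(g_values: list[list[float]], mask: list[int]) -> float:
  """Average g-value over unmasked positions and all depth layers.

  g_values has shape [seq_len][depth]; mask has shape [seq_len], with 1 for
  positions to score and 0 for positions to ignore.
  """
  depth = len(g_values[0])
  total = sum(sum(row) for row, keep in zip(g_values, mask) if keep)
  num_unmasked = sum(mask)
  return total / (depth * num_unmasked)


# Hypothetical g-values for a 4-token response scored at depth 3.
g = [[1, 1, 0], [1, 0, 1], [0, 0, 0], [1, 1, 1]]
m = [1, 1, 0, 1]  # Third position is masked out (e.g. a repeated context).
print(toy_mean_score(g, m))  # 7 / 9 ≈ 0.778
```

Under the assumption that g-values behave like fair coin flips for unwatermarked text, such scores hover around 0.5; watermarked text is expected to score higher.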


In [None]:
# @title Constants

NUM_NEGATIVES = 10000
POS_BATCH_SIZE = 32
NUM_POS_BATCHES = 313
NEG_BATCH_SIZE = 32
# Truncate outputs to this length for training.
POS_TRUNCATION_LENGTH = 200
NEG_TRUNCATION_LENGTH = 200
# Pad truncated outputs to this length for equal shape across all batches.
MAX_PADDED_LENGTH = 1000
TEMPERATURE = 1.0
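
The two length constants work together: each tokenized output is first cut to its truncation length and then right-padded to `MAX_PADDED_LENGTH` so every batch has the same shape. A minimal sketch of that idea follows; the helper name and the `pad_id` filler token are assumptions, not what `train_best_detector` necessarily does internally.

```python
def truncate_and_pad(
    token_ids: list[int],
    truncation_length: int,
    padded_length: int,
    pad_id: int,
) -> tuple[list[int], list[int]]:
  """Truncates token_ids, right-pads to padded_length, returns a valid mask."""
  if truncation_length > padded_length:
    raise ValueError('truncation_length must not exceed padded_length')
  kept = token_ids[:truncation_length]
  num_pad = padded_length - len(kept)
  padded = kept + [pad_id] * num_pad
  mask = [1] * len(kept) + [0] * num_pad  # 1 marks real (non-pad) tokens.
  return padded, mask


# Hypothetical values, scaled down from POS_TRUNCATION_LENGTH = 200 and
# MAX_PADDED_LENGTH = 1000 so the result is easy to read.
ids, valid = truncate_and_pad(
    [5, 6, 7, 8, 9], truncation_length=3, padded_length=6, pad_id=1
)
print(ids)    # [5, 6, 7, 1, 1, 1]
print(valid)  # [1, 1, 1, 0, 0, 0]
```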

In [None]:
# @title Generate model responses and compute g-values


def generate_responses(example_inputs, enable_watermarking):
  inputs = tokenizer(
      example_inputs,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)

  # Free GPU memory left over from previous models before loading a new one.
  gc.collect()
  torch.cuda.empty_cache()

  model = load_model(
      MODEL_NAME,
      expected_device=DEVICE,
      enable_watermarking=enable_watermarking,
  )
  torch.manual_seed(0)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
  )

  outputs = outputs[:, inputs_len:]

  # Compute the EOS mask (shape [batch_size, output_len]), then drop the first
  # ngram_len - 1 positions so it aligns with the g-values below.
  eos_token_mask = logits_processor.compute_eos_token_mask(
      input_ids=outputs,
      eos_token_id=tokenizer.eos_token_id,
  )[:, CONFIG['ngram_len'] - 1 :]

  # Compute the context repetition mask, of shape
  # [batch_size, output_len - (ngram_len - 1)].
  context_repetition_mask = logits_processor.compute_context_repetition_mask(
      input_ids=outputs,
  )

  combined_mask = context_repetition_mask * eos_token_mask

  g_values = logits_processor.compute_g_values(
      input_ids=outputs,
  )
  # g values shape [batch_size, output_len - (ngram_len - 1), depth]

  return g_values, combined_mask


example_inputs = [
    'I enjoy walking with my cute dog',
    'I am from New York',
    'The test was not so very hard after all',
    "I don't think they can score twice in so short a time",
]

wm_g_values, wm_mask = generate_responses(
    example_inputs, enable_watermarking=True
)
uwm_g_values, uwm_mask = generate_responses(
    example_inputs, enable_watermarking=False
)
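
`generate_responses` combines two masks: an EOS mask that drops everything from the end of the sequence onward, and a context-repetition mask that drops positions whose preceding context already appeared earlier in the response. The EOS mask is sliced by `ngram_len - 1` so both line up with the g-values. The pure-Python sketch below shows one plausible version, assuming the EOS mask keeps tokens strictly before the first EOS and that a position's context is the `ngram_len - 1` tokens preceding it. The helper names are hypothetical and are not the `logits_processing` API.

```python
def toy_eos_mask(tokens: list[int], eos_id: int) -> list[int]:
  """1 for tokens strictly before the first EOS, 0 from the EOS onward."""
  mask, seen_eos = [], False
  for tok in tokens:
    seen_eos = seen_eos or tok == eos_id
    mask.append(0 if seen_eos else 1)
  return mask


def toy_context_repetition_mask(tokens: list[int], ngram_len: int) -> list[int]:
  """1 if the (ngram_len - 1)-token context before a position is new, else 0.

  Returns len(tokens) - (ngram_len - 1) entries, one per scorable token.
  """
  ctx_len = ngram_len - 1
  seen, mask = set(), []
  for i in range(ctx_len, len(tokens)):
    context = tuple(tokens[i - ctx_len:i])
    mask.append(0 if context in seen else 1)
    seen.add(context)
  return mask


ngram_len, eos_id = 3, 0
tokens = [7, 8, 9, 7, 8, 5, 0, 0]  # Hypothetical IDs; context (7, 8) repeats.
eos = toy_eos_mask(tokens, eos_id)[ngram_len - 1:]  # Align with g-values.
rep = toy_context_repetition_mask(tokens, ngram_len)
combined = [e * r for e, r in zip(eos, rep)]
print(combined)  # [1, 1, 1, 0, 0, 0]
```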

## Option 1: Mean detector

In [None]:
# @title Get Mean detector scores for the generated outputs.

# Watermarked responses tend to have higher Mean scores than unwatermarked
# responses. To classify responses you can set a score threshold, but this will
# depend on the distribution of scores for your use-case and your desired false
# positive / false negative rates.

wm_mean_scores = detector_mean.mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_mean_scores = detector_mean.mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print('Mean scores for watermarked responses: ', wm_mean_scores)
print('Mean scores for unwatermarked responses: ', uwm_mean_scores)

# You may find that the Weighted Mean scoring function gives better
# classification performance than the Mean scoring function (in particular,
# higher scores for watermarked responses). See the paper for full details.

wm_weighted_mean_scores = detector_mean.weighted_mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_weighted_mean_scores = detector_mean.weighted_mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print(
    'Weighted Mean scores for watermarked responses: ', wm_weighted_mean_scores
)
print(
    'Weighted Mean scores for unwatermarked responses: ',
    uwm_weighted_mean_scores,
)
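
The Weighted Mean variant gives different depth layers different weights. As an illustration only, the sketch below assumes linearly decreasing weights (earlier layers weighted more) normalized to sum to the depth, so that uniform weights recover the plain Mean score. The weights the library actually uses may differ; treat `toy_weighted_mean_score` and `linear_weights` as hypothetical.

```python
def toy_weighted_mean_score(
    g_values: list[list[float]], mask: list[int], weights: list[float]
) -> float:
  """Weighted average of g-values over unmasked positions and depth layers."""
  depth = len(weights)
  scale = depth / sum(weights)  # Normalize so the weights sum to depth.
  norm_w = [w * scale for w in weights]
  total = sum(
      sum(w * g for w, g in zip(norm_w, row))
      for row, keep in zip(g_values, mask)
      if keep
  )
  return total / (depth * sum(mask))


def linear_weights(depth: int, start: float = 10.0, stop: float = 1.0) -> list[float]:
  """Hypothetical linearly decreasing weights from start to stop."""
  if depth == 1:
    return [start]
  step = (stop - start) / (depth - 1)
  return [start + i * step for i in range(depth)]


g = [[1, 0, 0], [1, 1, 0]]  # Hypothetical g-values at depth 3.
m = [1, 1]
print(toy_weighted_mean_score(g, m, linear_weights(3)))  # 17 / 22 ≈ 0.773
print(toy_weighted_mean_score(g, m, [1.0, 1.0, 1.0]))    # 0.5, the plain mean
```

With uniform weights the result matches the plain Mean score, which is a handy sanity check when experimenting with other weightings.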

## Option 2: Bayesian detector

In [None]:
# @title Generate watermarked samples for training Bayesian detector

gc.collect()
torch.cuda.empty_cache()

model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)
torch.manual_seed(0)

eli5_prompts = datasets.load_dataset("Pavithree/eli5")

wm_outputs = []

for batch_id in tqdm.tqdm(range(NUM_POS_BATCHES)):
  prompts = eli5_prompts['train']['title'][
      batch_id * POS_BATCH_SIZE:(batch_id + 1) * POS_BATCH_SIZE]
  prompts = [_process_raw_prompt(prompt.encode()) for prompt in prompts]
  inputs = tokenizer(
      prompts,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
  )

  wm_outputs.append(outputs[:, inputs_len:])

  del outputs, inputs, prompts

del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# @title Generate unwatermarked samples for training Bayesian detector

dataset, info = tfds.load('wikipedia/20230601.en', split='train', with_info=True)

dataset = dataset.take(NUM_NEGATIVES)

# Convert the dataset to a DataFrame
df = tfds.as_dataframe(dataset, info)
ds = tf.data.Dataset.from_tensor_slices(dict(df))
tf.random.set_seed(0)
ds = ds.shuffle(buffer_size=10_000)
ds = ds.batch(batch_size=1)

tokenized_uwm_outputs = []
lengths = []
batched = []
# Pad to this length (on the right) for batching.
padded_length = 2500
for i, batch in tqdm.tqdm(enumerate(ds)):
  responses = [val.decode() for val in batch['text'].numpy()]
  inputs = tokenizer(
      responses,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)
  line = inputs['input_ids'].cpu().numpy()[0].tolist()
  if len(line) >= padded_length:
    line = line[:padded_length]
  else:
    line = line + [
        tokenizer.eos_token_id for _ in range(padded_length - len(line))
    ]
  batched.append(torch.tensor(line, dtype=torch.long, device=DEVICE)[None, :])
  if len(batched) == NEG_BATCH_SIZE:
    tokenized_uwm_outputs.append(torch.cat(batched, dim=0))
    batched = []
  if i > NUM_NEGATIVES:
    break
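
The loop above pads or truncates each Wikipedia article to `padded_length` tokens and stacks the rows into batches of `NEG_BATCH_SIZE`; a final partial batch is never appended. A small pure-Python sketch of that grouping step (the function name is hypothetical) makes the behaviour explicit:

```python
def group_full_batches(
    rows: list[list[int]], batch_size: int
) -> list[list[list[int]]]:
  """Groups rows into batches of batch_size, dropping a trailing partial batch."""
  num_full = len(rows) // batch_size
  return [rows[b * batch_size:(b + 1) * batch_size] for b in range(num_full)]


# Hypothetical rows, already padded to length 4 with a filler ID of 0.
rows = [[1, 2, 0, 0], [3, 4, 5, 6], [8, 0, 0, 0], [9, 9, 9, 0]]
batches = group_full_batches(rows, batch_size=3)
print(len(batches))  # 1: the fourth row is dropped.
```

With the notebook's settings (10,000 articles and `NEG_BATCH_SIZE = 32`), this arithmetic gives 312 full batches and leaves 16 articles unused.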

In [None]:
# @title Train the Bayesian detector
bayesian_detector, test_loss = (
    detector_bayesian.BayesianDetector.train_best_detector(
        tokenized_wm_outputs=wm_outputs,
        tokenized_uwm_outputs=tokenized_uwm_outputs,
        logits_processor=logits_processor,
        tokenizer=tokenizer,
        torch_device=DEVICE,
        max_padded_length=MAX_PADDED_LENGTH,
        pos_truncation_length=POS_TRUNCATION_LENGTH,
        neg_truncation_length=NEG_TRUNCATION_LENGTH,
        verbose=True,
        learning_rate=3e-3,
        n_epochs=100,
        l2_weights=np.zeros((1,)),
    )
)

In [None]:
# @title Get Bayesian detector scores for the generated outputs.

# Watermarked responses tend to have higher Bayesian scores than unwatermarked
# responses. To classify responses you can set a score threshold, but this will
# depend on the distribution of scores for your use-case and your desired false
# positive / false negative rates. See the paper for full details.

wm_bayesian_scores = bayesian_detector.score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_bayesian_scores = bayesian_detector.score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print('Bayesian scores for watermarked responses: ', wm_bayesian_scores)
print('Bayesian scores for unwatermarked responses: ', uwm_bayesian_scores)
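
Both detectors leave the choice of threshold to you. One common recipe, sketched here under the assumption that you have scores for a held-out set of known-unwatermarked responses, is to pick the threshold that yields a target false positive rate on those negatives and then measure the true positive rate on known-watermarked responses. The scores and helper names below are made up for illustration.

```python
import math


def threshold_for_fpr(negative_scores: list[float], target_fpr: float) -> float:
  """Returns a negative score to use as the threshold for a target FPR.

  At most a target_fpr fraction of negative_scores lie strictly above it.
  """
  if not 0.0 <= target_fpr < 1.0:
    raise ValueError('target_fpr must be in [0, 1)')
  ranked = sorted(negative_scores, reverse=True)
  num_allowed = math.floor(target_fpr * len(ranked))  # Negatives allowed above.
  return ranked[num_allowed]


def rate_above(scores: list[float], threshold: float) -> float:
  """Fraction of scores strictly above threshold."""
  return sum(s > threshold for s in scores) / len(scores)


uwm = [0.48, 0.50, 0.51, 0.49, 0.53, 0.47, 0.52, 0.50, 0.55, 0.46]
wm = [0.61, 0.58, 0.66, 0.52, 0.63]
thr = threshold_for_fpr(uwm, target_fpr=0.1)
print(thr, rate_above(uwm, thr), rate_above(wm, thr))  # 0.53 0.1 0.8
```

Classify a response as watermarked when its score is strictly above the threshold. The estimate is only as reliable as the negative set: with few negatives, very low target rates such as 1% cannot be resolved.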