# SynthID Text: Watermarking for Generated Text

This notebook demonstrates how to use the [SynthID Text library][synthid-code]
to apply and detect watermarks on generated text. It is divided into three major
sections and intended to be run end-to-end.

1.  **_Setup_**: Importing the SynthID Text library, choosing your model (either
    [Gemma][gemma] or [GPT-2][gpt2]) and device (either CPU or GPU, depending
    on your runtime), defining the watermarking configuration, and initializing
    some helper functions.
1.  **_Applying a watermark_**: Loading your selected model using the
    [Hugging Face Transformers][transformers] library, using that model to
    generate some watermarked text, and comparing the perplexity of the
    watermarked text to that of text generated by the base model.
1.  **_Detecting a watermark_**: Training a detector to recognize text generated
    with a specific watermarking configuration, and then using that detector to
    predict whether a set of examples were generated with that configuration.

As the reference implementation for the
[SynthID Text paper in _Nature_][synthid-paper], this library and notebook are
intended for research review and reproduction only. They should not be used in
production systems. For a production-grade implementation, check out the
official SynthID logits processor in [Hugging Face Transformers][transformers].

[gemma]: https://ai.google.dev/gemma/docs/model_card
[gpt2]: https://huggingface.co/openai-community/gpt2
[synthid-code]: https://github.com/google-deepmind/synthid-text
[synthid-paper]: https://www.nature.com/
[transformers]: https://huggingface.co/docs/transformers/en/index

# 1. Setup

In [None]:
!pip install googletrans



In [None]:
# @title Install and import the required Python packages
#
# @markdown Running this cell may require you to restart your session.

! pip install synthid-text[notebook]

from collections.abc import Sequence
import enum
import gc

import datasets
import huggingface_hub
from synthid_text import detector_mean
from synthid_text import logits_processing
from synthid_text import synthid_mixin
from synthid_text import detector_bayesian
import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
import transformers



In [None]:
# @title Choose your model.
#
# @markdown This reference implementation is configured to use the Gemma v1.0
# @markdown Instruction-Tuned variants in 2B or 7B sizes, or GPT-2.


class ModelName(enum.Enum):
  GPT2 = 'gpt2'
  GEMMA_2B = 'google/gemma-2b-it'
  GEMMA_7B = 'google/gemma-7b-it'


model_name = 'google/gemma-2b-it' # @param ['gpt2', 'google/gemma-2b-it', 'google/gemma-7b-it']
MODEL_NAME = ModelName(model_name)

if MODEL_NAME is not ModelName.GPT2:
  huggingface_hub.notebook_login()


In [None]:
# @title Configure your device
#
# @markdown This notebook loads models from Hugging Face Transformers into the
# @markdown PyTorch deep learning runtime. PyTorch supports generation on CPU or
# @markdown GPU, but your chosen model will run best on the following hardware,
# @markdown some of which may require a
# @markdown [Colab Subscription](https://colab.research.google.com/signup).
# @markdown
# @markdown * Gemma v1.0 2B IT: Use a GPU with 16GB of memory, such as a T4.
# @markdown * Gemma v1.0 7B IT: Use a GPU with 32GB of memory, such as an A100.
# @markdown * GPT-2: Any runtime will work, though a High-RAM CPU or any GPU
# @markdown   will be faster.

DEVICE = (
    torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
)
DEVICE

device(type='cuda', index=0)

In [None]:
# @title Example watermarking config
#
# @markdown SynthID Text produces unique watermarks given a configuration, with
# @markdown the most important piece of a configuration being the `keys`: a
# @markdown sequence of unique integers.
# @markdown
# @markdown This reference implementation uses a fixed watermarking
# @markdown configuration, which will be displayed when you run this cell.

CONFIG = synthid_mixin.DEFAULT_WATERMARKING_CONFIG
CONFIG

immutabledict({'ngram_len': 5, 'keys': [654, 400, 836, 123, 340, 443, 597, 160, 57, 29, 590, 639, 13, 715, 468, 990, 966, 226, 324, 585, 118, 504, 421, 521, 129, 669, 732, 225, 90, 960], 'sampling_table_size': 65536, 'sampling_table_seed': 0, 'context_history_size': 1024, 'device': device(type='cuda', index=0)})
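
To experiment with a watermark of your own, you would mainly vary the `keys`. The following is a minimal sketch, assuming a configuration is a plain mapping with the same fields as the default shown above. `make_watermarking_config` is a hypothetical helper and not part of the SynthID Text library, and the `device` entry is omitted here, so it would need to be added before use.

```python
import random


def make_watermarking_config(num_keys=30, ngram_len=5, seed=0):
    """Builds a config dict shaped like the default one (hypothetical helper)."""
    rng = random.Random(seed)
    # `keys` must be unique integers; random.sample draws without replacement.
    keys = rng.sample(range(1, 1000), num_keys)
    return {
        'ngram_len': ngram_len,
        'keys': keys,
        'sampling_table_size': 2**16,
        'sampling_table_seed': 0,
        'context_history_size': 1024,
    }


custom_config = make_watermarking_config(seed=42)
print(len(custom_config['keys']), custom_config['ngram_len'])
```

Seeding the generator keeps the keys reproducible, which matters because a detector only recognizes text produced with the exact configuration it was built for.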

In [None]:
# @title Initialize the required constants, tokenizer, and logits processor

BATCH_SIZE = 8
NUM_BATCHES = 320
OUTPUTS_LEN = 1024
TEMPERATURE = 0.5
TOP_K = 40
TOP_P = 0.99

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME.value)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

logits_processor = logits_processing.SynthIDLogitsProcessor(
    **CONFIG, top_k=TOP_K, temperature=TEMPERATURE
)

In [None]:
# @title Utility functions to load models, compute perplexity, and process prompts.


def load_model(
    model_name: ModelName,
    expected_device: torch.device,
    enable_watermarking: bool = False,
) -> transformers.PreTrainedModel:
  match model_name:
    case ModelName.GPT2:
      model_cls = (
          synthid_mixin.SynthIDGPT2LMHeadModel
          if enable_watermarking
          else transformers.GPT2LMHeadModel
      )
      model = model_cls.from_pretrained(model_name.value, device_map='auto')
    case ModelName.GEMMA_2B | ModelName.GEMMA_7B:
      model_cls = (
          synthid_mixin.SynthIDGemmaForCausalLM
          if enable_watermarking
          else transformers.GemmaForCausalLM
      )
      model = model_cls.from_pretrained(
          model_name.value,
          device_map='auto',
          torch_dtype=torch.bfloat16,
      )

  if model.device != expected_device:
    raise ValueError('Model device not as expected.')
  return model


def _compute_perplexity(
    outputs: torch.LongTensor,
    scores: torch.FloatTensor,
    eos_token_mask: torch.LongTensor,
    watermarked: bool = False,
) -> float:
  """Compute perplexity given the model outputs and the logits."""
  len_offset = len(scores)
  if watermarked:
    nll_scores = scores
  else:
    nll_scores = [
        torch.gather(
            -torch.log(torch.nn.Softmax(dim=1)(sc)),
            1,
            outputs[:, -len_offset + idx, None],
        )
        for idx, sc in enumerate(scores)
    ]
  nll_sum = torch.nan_to_num(
      torch.squeeze(torch.stack(nll_scores, dim=1), dim=2)
      * eos_token_mask.long(),
      posinf=0,
  )
  nll_sum = nll_sum.sum(dim=1)
  nll_mean = nll_sum / eos_token_mask.sum(dim=1)
  return nll_mean.sum(dim=0)


def _process_raw_prompt(prompt: bytes) -> str:
  """Decode the raw prompt bytes and, for Gemma, wrap them in the chat template."""
  match MODEL_NAME:
    case ModelName.GPT2:
      return prompt.decode().strip('"')
    case ModelName.GEMMA_2B | ModelName.GEMMA_7B:
      return tokenizer.apply_chat_template(
          [{'role': 'user', 'content': prompt.decode().strip('"')}],
          tokenize=False,
          add_generation_prompt=True,
      )
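
Perplexity is the exponential of the mean per-token negative log-likelihood. The sketch below illustrates that relationship on plain Python lists, independent of the tensor-based helper above. The names `token_nll` and `perplexity` are illustrative assumptions, not library functions.

```python
import math


def token_nll(logits, token_id):
    """Negative log-likelihood of `token_id` under softmax(logits)."""
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[token_id]


def perplexity(step_logits, token_ids, mask):
    """exp(mean NLL) over the positions where mask is 1."""
    nlls = [
        token_nll(logits, tok)
        for logits, tok, keep in zip(step_logits, token_ids, mask)
        if keep
    ]
    return math.exp(sum(nlls) / len(nlls))


# Uniform logits over a 4-token vocabulary; the last position is masked out.
uniform_logits = [[0.0, 0.0, 0.0, 0.0]] * 3
ppl = perplexity(uniform_logits, token_ids=[0, 1, 2], mask=[1, 1, 0])
print(ppl)
```

A uniform distribution over four tokens is a useful sanity check because every token has probability 1/4 regardless of which one was sampled.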

# 2. Applying a watermark

In [None]:
# @title Generate watermarked output

gc.collect()
torch.cuda.empty_cache()

batch_size = 1
example_inputs = [
    'Once upon a time in Carnegie Mellon University, a group of students came up with a brilliant idea.',
]
example_inputs = example_inputs * (int(batch_size / 4) + 1)
example_inputs = example_inputs[:batch_size]

inputs = tokenizer(
    example_inputs,
    return_tensors='pt',
    padding=True,
).to(DEVICE)

model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)
torch.manual_seed(0)
outputs = model.generate(
    **inputs,
    do_sample=True,
    temperature=0.7,
    max_length=1024,
    top_k=40,
)

print('Output:\n' + 100 * '-')
for i, output in enumerate(outputs):
  print(tokenizer.decode(output, skip_special_tokens=True))
  print(100 * '-')

del inputs, outputs, model
gc.collect()
torch.cuda.empty_cache()



# 3. Detecting a watermark

To detect the watermark, you have two options:

1.  Use the simple **Mean** scoring function. This is fast and requires no training.
2.  Use the more powerful **Bayesian** scoring function. This requires training a detector and takes more time.

For a full explanation of these scoring functions, see the paper and its Supplementary Materials.
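
As a rough illustration of the first option, the sketch below assumes the Mean score is the average g-value over all unmasked tokens and all watermarking layers. `mean_score_sketch` is a hypothetical stand-in written in NumPy; the library's `detector_mean.mean_score` may differ in detail.

```python
import numpy as np


def mean_score_sketch(g_values, mask):
    """Averages g-values over unmasked tokens and depth.

    g_values: [batch, seq_len, depth] array of 0/1 values.
    mask: [batch, seq_len] array of 0/1 values (1 = token is scored).
    """
    depth = g_values.shape[-1]
    num_unmasked = mask.sum(axis=1)
    # Broadcast the mask over the depth axis, then sum over tokens and depth.
    masked_sum = (g_values * mask[:, :, None]).sum(axis=(1, 2))
    return masked_sum / (depth * num_unmasked)


rng = np.random.default_rng(0)
# Synthetic stand-in for unwatermarked text: g-values are fair coin flips.
random_g_values = rng.integers(0, 2, size=(1, 400, 30))
all_tokens = np.ones((1, 400), dtype=int)
score = mean_score_sketch(random_g_values, all_tokens)
print(score)
```

For random g-values the score sits near 0.5, and watermarking is expected to push it above that, which is why a threshold on this score can separate the two kinds of text.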


In [None]:
# @title Constants

NUM_NEGATIVES = 10000
POS_BATCH_SIZE = 32
NUM_POS_BATCHES = 313
NEG_BATCH_SIZE = 32
# Truncate outputs to this length for training.
POS_TRUNCATION_LENGTH = 400
NEG_TRUNCATION_LENGTH = 400
# Pad truncated outputs to this length for equal shape across all batches.
MAX_PADDED_LENGTH = 1000
TEMPERATURE = 1.0

In [None]:
# @title Generate model responses and compute g-values


def generate_responses(example_inputs, enable_watermarking):
  inputs = tokenizer(
      example_inputs,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)

  # Free cached GPU memory before loading the model.
  gc.collect()
  torch.cuda.empty_cache()

  model = load_model(
      MODEL_NAME,
      expected_device=DEVICE,
      enable_watermarking=enable_watermarking,
  )
  torch.manual_seed(0)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
  )

  outputs = outputs[:, inputs_len:]

  # print(25 * '-' + 'OUTPUT' + 25 * '-')
  # print(tokenizer.decode(outputs[0], skip_special_tokens=True))

  # eos mask is computed, skip first ngram_len - 1 tokens
  # eos_mask will be of shape [batch_size, output_len]
  eos_token_mask = logits_processor.compute_eos_token_mask(
      input_ids=outputs,
      eos_token_id=tokenizer.eos_token_id,
  )[:, CONFIG['ngram_len'] - 1 :]

  # context repetition mask is computed
  context_repetition_mask = logits_processor.compute_context_repetition_mask(
      input_ids=outputs,
  )
  # context repetition mask shape [batch_size, output_len - (ngram_len - 1)]

  combined_mask = context_repetition_mask * eos_token_mask

  g_values = logits_processor.compute_g_values(
      input_ids=outputs,
  )
  # g values shape [batch_size, output_len - (ngram_len - 1), depth]

  return outputs, g_values, combined_mask


example_inputs = [
    'Once upon a time in Carnegie Mellon University, a group of students came up with a brilliant idea.',
]

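# Uncomment to compute the g-values and masks used by the Mean detector cell
# below. `generate_responses` returns (outputs, g_values, combined_mask).
# _, wm_g_values, wm_mask = generate_responses(
#     example_inputs, enable_watermarking=True
# )
# _, uwm_g_values, uwm_mask = generate_responses(
#     example_inputs, enable_watermarking=False
# )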
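
The slicing by `ngram_len - 1` in the cell above aligns the EOS mask with the g-values, which only exist from the first complete n-gram onward. The NumPy sketch below illustrates this, assuming the EOS mask is 0 from the first EOS token onward and 1 before it. `eos_token_mask_sketch` is a hypothetical stand-in for `compute_eos_token_mask`, not the library function.

```python
import numpy as np


def eos_token_mask_sketch(token_ids, eos_token_id):
    """Returns 1 for positions before the first EOS in each row, else 0."""
    is_eos = token_ids == eos_token_id
    # The running count of EOS tokens is 0 only before the first EOS.
    return (np.cumsum(is_eos, axis=1) == 0).astype(int)


ngram_len = 5
# One sequence of 8 tokens; token id 0 plays the role of EOS here.
tokens = np.array([[11, 12, 13, 14, 15, 16, 0, 0]])
full_mask = eos_token_mask_sketch(tokens, eos_token_id=0)
# Drop the first ngram_len - 1 positions so the mask matches the g-value shape.
trimmed_mask = full_mask[:, ngram_len - 1:]
print(full_mask.shape, trimmed_mask.shape)
```

The trimmed mask has `output_len - (ngram_len - 1)` columns, the same length the comments above give for the context repetition mask and the g-values, so the masks can be multiplied elementwise.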

## Translation attack

In [None]:
import random
import googletrans
from googletrans import Translator
from typing import List, Tuple
translator = Translator()

In [None]:
def get_available_languages() -> List[str]:
    """Get a list of language codes supported by googletrans."""
    # A subset of common languages to keep translations meaningful
    return ['es', 'fr', 'de', 'it', 'pt', 'ru', 'ja', 'ko', 'zh-cn', 'ar', 'cs', 'so']

In [None]:
prompts = [
  "Write a paragraph about climate change",
  "Describe a beautiful sunset",
  "Explain how the internet works",
  "Tell me about the importance of exercise",
  "Tell me about Carnegie Mellon",
  "What do you think the future looks like",
  "How can I make a delicious pizza in a few different ways",
  "What goes into making a computer",
  "Tell me about the Roman Empire",
  "Tell me how large language models work",
  "How does DNA sequencing work",
]

In [None]:
def translate_text_chain(text: str, num_translations: int) -> Tuple[str, List[str]]:
    """Translate text through a chain of random languages and back to English."""
    translator = Translator()
    languages = get_available_languages()
    translation_path = []
    current_text = text

    for step in range(num_translations):
        target_lang = random.choice(languages)
        translation_path.append(target_lang)
        try:
            current_text = translator.translate(current_text, dest=target_lang).text
            # After the last hop, translate the result back to English.
            if step == num_translations - 1:
                current_text = translator.translate(current_text, dest='en').text
        except Exception as e:
            # Stop early on failure; current_text may still be non-English.
            print(f"Translation error: {e}")
            break

    return current_text, translation_path
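
Because the function above draws languages from the global random state and calls a network service, its results are hard to reproduce or test. The sketch below shows one possible variant that takes a seed and an injected `translate` callable. `round_trip_chain` and `fake_translate` are hypothetical names introduced for illustration only.

```python
import random
from typing import Callable, List, Tuple


def round_trip_chain(
    text: str,
    languages: List[str],
    num_translations: int,
    translate: Callable[[str, str], str],
    seed: int = 0,
) -> Tuple[str, List[str]]:
    """Translates through a seeded random path, then back to English."""
    rng = random.Random(seed)
    path = [rng.choice(languages) for _ in range(num_translations)]
    for lang in path + ['en']:
        text = translate(text, lang)
    return text, path


def fake_translate(text: str, dest: str) -> str:
    """Offline stand-in for a translator: records the hop instead of translating."""
    return f'{text}>{dest}'


final_text, path = round_trip_chain('hello', ['es', 'fr', 'de'], 3, fake_translate)
print(final_text, path)
```

With a real translator, the callable would wrap the translation call used above. The seed makes the language path repeatable across runs, though the translations themselves may still change over time.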

In [None]:
def analyze_watermark_translation_attack(
    prompt: str,
    num_translation_variants: int = 3,
    max_translations_per_variant: int = 3
) -> dict:
    """Analyze how watermark detection scores change after translation attacks on the response."""

    # Generate initial watermarked response
    wm_outputs, wm_g_values, wm_mask = generate_responses([prompt], enable_watermarking=True)
    uwm_outputs, uwm_g_values, uwm_mask = generate_responses([prompt], enable_watermarking=False)

    # Decode the watermarked response
    watermarked_response = tokenizer.decode(wm_outputs[0], skip_special_tokens=True)

    # Get initial scores
    original_wm_score = detector_mean.mean_score(wm_g_values.cpu().numpy(), wm_mask.cpu().numpy())
    original_uwm_score = detector_mean.mean_score(uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy())

    results = {
        'prompt': prompt,
        'original_response': watermarked_response,
        'original_watermarked_score': float(original_wm_score[0]),
        'original_unwatermarked_score': float(original_uwm_score[0]),
        'translation_variants': []
    }

    # Generate multiple translation variants of the watermarked response
    for i in range(num_translation_variants):
        num_translations = random.randint(1, max_translations_per_variant)
        translated_text, translation_path = translate_text_chain(watermarked_response, num_translations)

        # Tokenize translated text
        translated_tokens = tokenizer(
            translated_text,
            return_tensors='pt',
            padding=True
        ).input_ids.to(DEVICE)

        # Compute g-values and mask for translated text
        trans_g_values = logits_processor.compute_g_values(
            input_ids=translated_tokens
        )
        trans_mask = logits_processor.compute_eos_token_mask(
            input_ids=translated_tokens,
            eos_token_id=tokenizer.eos_token_id
        )[:, CONFIG['ngram_len'] - 1:]

        # Get score for translated version
        trans_mean_scores = detector_mean.mean_score(
            trans_g_values.cpu().numpy(),
            trans_mask.cpu().numpy()
        )

        variant_result = {
            'variant_id': i + 1,
            'num_translations': num_translations,
            'translation_path': translation_path,
            'translated_text': translated_text,
            'watermark_score': float(trans_mean_scores[0])
        }

        results['translation_variants'].append(variant_result)

    return results

In [None]:
# Test prompt
prompt = "Explain the importance of renewable energy in modern society"

# Analyze watermark robustness under translation
results = analyze_watermark_translation_attack(
    prompt,
    num_translation_variants=3,
    max_translations_per_variant=3
)

# Print results
print(f"Prompt: {results['prompt']}\n")
print(f"Original response:\n{results['original_response']}\n")
print(f"Original watermarked score: {results['original_watermarked_score']:.4f}")
print(f"Original unwatermarked score: {results['original_unwatermarked_score']:.4f}\n")

for variant in results['translation_variants']:
    print(f"Variant {variant['variant_id']}:")
    print(f"Number of translations: {variant['num_translations']}")
    print(f"Translation path: {' -> '.join(variant['translation_path'])} -> en")
    print(f"Watermark score: {variant['watermark_score']:.4f}")
    print(f"Translated text:\n{variant['translated_text']}\n")

In [None]:
import os
from collections import defaultdict
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import numpy as np

In [None]:
def run_translation_analysis(prompts: List[str], max_translations: int = 5, trials_per_length: int = 3) -> Dict[str, Any]:
    """Run translation analysis on multiple prompts with varying translation chain lengths."""
    results = {}

    for prompt in prompts:
        # Generate initial watermarked response
        wm_outputs, wm_g_values, wm_mask = generate_responses([prompt], enable_watermarking=True)
        watermarked_response = tokenizer.decode(wm_outputs[0], skip_special_tokens=True)

        # Get original scores
        original_mean_score = float(detector_mean.mean_score(
            wm_g_values.cpu().numpy(),
            wm_mask.cpu().numpy()
        )[0])
        original_weighted_mean_score = float(detector_mean.weighted_mean_score(
            wm_g_values.cpu().numpy(),
            wm_mask.cpu().numpy()
        )[0])

        # Store results for this prompt
        prompt_results = {
            'prompt': prompt,
            'num_tokens': len(tokenizer.encode(watermarked_response)),
            'original_response': watermarked_response,
            'original_mean_score': original_mean_score,
            'original_weighted_mean_score': original_weighted_mean_score,
            'translation_results': defaultdict(list)  # Keyed by num_translations
        }

        # Test different translation chain lengths
        for num_translations in range(1, max_translations + 1):
            # Multiple trials per translation length
            for _ in range(trials_per_length):
                translated_text, translation_path = translate_text_chain(
                    watermarked_response,
                    num_translations
                )

                # Tokenize translated text
                translated_tokens = tokenizer(
                    translated_text,
                    return_tensors='pt',
                    padding=True
                ).input_ids.to(DEVICE)

                # Compute g-values and mask for translated text
                trans_g_values = logits_processor.compute_g_values(
                    input_ids=translated_tokens
                )
                trans_mask = logits_processor.compute_eos_token_mask(
                    input_ids=translated_tokens,
                    eos_token_id=tokenizer.eos_token_id
                )[:, CONFIG['ngram_len'] - 1:]

                # Get both types of scores
                mean_score = float(detector_mean.mean_score(
                    trans_g_values.cpu().numpy(),
                    trans_mask.cpu().numpy()
                )[0])
                weighted_mean_score = float(detector_mean.weighted_mean_score(
                    trans_g_values.cpu().numpy(),
                    trans_mask.cpu().numpy()
                )[0])

                # Store results
                prompt_results['translation_results'][num_translations].append({
                    'mean_score': mean_score,
                    'weighted_mean_score': weighted_mean_score,
                    'translation_path': translation_path
                })

        results[prompt] = prompt_results

    return results

In [None]:
def plot_translation_analysis(results: Dict[str, Any], output_dir: str = 'plots'):
    """Create plots for each prompt showing how scores change with translation chain length."""
    os.makedirs(output_dir, exist_ok=True)

    for prompt_idx, (prompt, prompt_results) in enumerate(results.items()):
        # Create figure with two subplots
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

        # Extract data for plotting
        num_translations = sorted(prompt_results['translation_results'].keys())
        mean_scores = []
        weighted_mean_scores = []
        mean_stds = []
        weighted_mean_stds = []

        for n in num_translations:
            trials = prompt_results['translation_results'][n]
            mean_scores.append(np.mean([t['mean_score'] for t in trials]))
            weighted_mean_scores.append(np.mean([t['weighted_mean_score'] for t in trials]))
            mean_stds.append(np.std([t['mean_score'] for t in trials]))
            weighted_mean_stds.append(np.std([t['weighted_mean_score'] for t in trials]))

        # Plot Mean Scores
        ax1.errorbar(num_translations, mean_scores, yerr=mean_stds,
                    marker='o', linestyle='-', capsize=5)
        ax1.axhline(y=prompt_results['original_mean_score'], color='r',
                   linestyle='--', label='Original Score')
        ax1.set_xlabel('Number of Translations')
        ax1.set_ylabel('Mean Score')
        ax1.set_title(f'Unweighted Mean Score vs Translations\n{prompt_results["num_tokens"]} tokens')
        ax1.legend()
        ax1.grid(True)

        # Plot Weighted Mean Scores
        ax2.errorbar(num_translations, weighted_mean_scores, yerr=weighted_mean_stds,
                    marker='o', linestyle='-', capsize=5)
        ax2.axhline(y=prompt_results['original_weighted_mean_score'], color='r',
                    linestyle='--', label='Original Score')
        ax2.set_xlabel('Number of Translations')
        ax2.set_ylabel('Weighted Mean Score')
        ax2.set_title(f'Weighted Mean Score vs Translations\n{prompt_results["num_tokens"]} tokens')
        ax2.legend()
        ax2.grid(True)

        # Add prompt as overall title
        truncated_prompt = prompt[:50] + '...' if len(prompt) > 50 else prompt
        fig.suptitle(f'Prompt {prompt_idx + 1}: {truncated_prompt}', fontsize=12)

        # Adjust layout and save
        plt.tight_layout()
        plt.savefig(f'{output_dir}/prompt_{prompt_idx + 1}_analysis.png', dpi=300, bbox_inches='tight')
        plt.close()

In [None]:
import pickle
with open('results.pkl', 'wb') as f:
    pickle.dump(results, f)

In [None]:
prompts = [
    "Describe how the process of fermentation works in making bread.",
    "Explain the difference between renewable and non-renewable energy sources.",
    "How do electric cars work and what are their benefits?",
    "What are the effects of deforestation on the environment?",
    "Describe the journey of a drop of blood through the human body.",
    "What are the major components of a computer network?",
    "How does machine learning differ from traditional programming?",
    "Explain how the stock market operates.",
    "What are the most common mental health disorders and their symptoms?",
    "Describe the process of cell division (mitosis and meiosis).",
    "What are the primary causes of inflation in an economy?",
    "How do plants adapt to different environments?",
    "What are the key principles of quantum mechanics?",
    "Explain the significance of the Industrial Revolution.",
]

# Run analysis
print("Running translation analysis...")
results = run_translation_analysis(
    prompts=prompts,
    max_translations=5,
    trials_per_length=50
)

# Create plots
print("Creating plots...")
plot_translation_analysis(results)

print("Analysis complete! Plots have been saved to the 'plots' directory.")

In [None]:
print(list(results.values())[1]['original_response'])



**Renewable energy sources**

* Sustainable and replenishable
* Examples include solar energy, wind energy, hydro power, and biomass

**Non-renewable energy sources**

* Are finite and will eventually run out
* Examples include fossil fuels, coal, oil, and natural gas

Sure. Here is the difference between renewable and non-renewable energy sources:

**Renewable energy sources** are sources of energy that can be replaced naturally on a human timescale. This means that they are not finite and will eventually run out, unlike non-renewable energy sources.

**Non-renewable energy sources** are sources of energy that are not replenished naturally on a human timescale. This means that they are finite and will eventually run out.

**Here is a table summarizing the key differences between renewable and non-renewable energy sources:**

| Feature | Renewable Energy | Non-Renewable Energy |
|---|---|---|
| Sustainability | Sustainable | Finite |
| Replenishment | Natural | Not natural |
| Exampl

In [None]:
current_text, _ = translate_text_chain(list(results.values())[1]['original_response'], 5)
print(current_text)

** Renewable energy sources **

* Durable and interchangeable
* Examples are solar energy, wind energy, hydroelectricity and the mass of life

** Energy sources are not renewable **

* Finally and they will go at some point
* Examples are fossil fuels, coal, oil and natural gas

Admittedly, the difference between renewable and non -renewable energy sources:

** Renewable energy sources ** are energy sources that can be replaced by natural time in human times.

** Non -renewable energy sources ** are energy sources that are generally not renovated in a humanitarian period, which means that they are finally and start at some point.

** Here is a calendar in which the most important differences between renewable and non -perpetitive energy sources are summarized: **

|
|
|
|
|
|

Renewable energy sources have a much lower environmental effect than non -renewable energy sources.

In addition, renewable energy sources have become increasingly effective, making it a more attractive option fo

In [None]:
def create_summary_plots(results: Dict[str, Any], output_dir: str = 'plots'):
    """Create summary plots showing averaged results across all prompts."""

    # Initialize data structures for collecting statistics
    max_translations = max(
        max(prompt_results['translation_results'].keys())
        for prompt_results in results.values()
    )

    # For each number of translations, collect all scores across prompts
    mean_scores_by_translations = defaultdict(list)
    weighted_mean_scores_by_translations = defaultdict(list)
    original_mean_scores = []
    original_weighted_mean_scores = []

    # Collect all scores
    for prompt_results in results.values():
        original_mean_scores.append(prompt_results['original_mean_score'])
        original_weighted_mean_scores.append(prompt_results['original_weighted_mean_score'])

        for num_trans, trials in prompt_results['translation_results'].items():
            mean_scores_by_translations[num_trans].extend(
                trial['mean_score'] for trial in trials
            )
            weighted_mean_scores_by_translations[num_trans].extend(
                trial['weighted_mean_score'] for trial in trials
            )

    # Calculate statistics
    num_translations = sorted(mean_scores_by_translations.keys())

    mean_scores_avg = [np.mean(mean_scores_by_translations[n]) for n in num_translations]
    mean_scores_std = [np.std(mean_scores_by_translations[n]) for n in num_translations]

    weighted_mean_scores_avg = [np.mean(weighted_mean_scores_by_translations[n]) for n in num_translations]
    weighted_mean_scores_std = [np.std(weighted_mean_scores_by_translations[n]) for n in num_translations]

    original_mean_score_avg = np.mean(original_mean_scores)
    original_mean_score_std = np.std(original_mean_scores)

    original_weighted_mean_score_avg = np.mean(original_weighted_mean_scores)
    original_weighted_mean_score_std = np.std(original_weighted_mean_scores)

    # Create summary plots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    # Plot Mean Scores
    ax1.errorbar(num_translations, mean_scores_avg, yerr=mean_scores_std,
                marker='o', linestyle='-', capsize=5, label='Translation Score')
    ax1.axhline(y=original_mean_score_avg, color='r', linestyle='--',
                label=f'Original Score (μ={original_mean_score_avg:.3f}, σ={original_mean_score_std:.3f})')

    # Add shaded region for original score standard deviation
    ax1.axhspan(
        original_mean_score_avg - original_mean_score_std,
        original_mean_score_avg + original_mean_score_std,
        color='r', alpha=0.1
    )

    ax1.set_xlabel('Number of Translations')
    ax1.set_ylabel('Mean Score')
    ax1.set_title('Average Unweighted Mean Score vs Translations\nAcross All Prompts')
    ax1.legend()
    ax1.grid(True)

    # Plot Weighted Mean Scores
    ax2.errorbar(num_translations, weighted_mean_scores_avg, yerr=weighted_mean_scores_std,
                marker='o', linestyle='-', capsize=5, label='Translation Score')
    ax2.axhline(y=original_weighted_mean_score_avg, color='r', linestyle='--',
                label=f'Original Score (μ={original_weighted_mean_score_avg:.3f}, σ={original_weighted_mean_score_std:.3f})')

    # Add shaded region for original score standard deviation
    ax2.axhspan(
        original_weighted_mean_score_avg - original_weighted_mean_score_std,
        original_weighted_mean_score_avg + original_weighted_mean_score_std,
        color='r', alpha=0.1
    )

    ax2.set_xlabel('Number of Translations')
    ax2.set_ylabel('Weighted Mean Score')
    ax2.set_title('Average Weighted Mean Score vs Translations\nAcross All Prompts')
    ax2.legend()
    ax2.grid(True)

    # Add overall title
    num_prompts = len(results)
    total_trials = sum(len(trials) for prompt_results in results.values()
                      for trials in prompt_results['translation_results'].values())

    fig.suptitle(
        f'Summary of Watermark Translation Analysis\n'
        f'{num_prompts} Prompts, {total_trials} Total Translation Trials',
        fontsize=14
    )

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(f'{output_dir}/summary_analysis.png', dpi=300, bbox_inches='tight')
    plt.close()

In [None]:
create_summary_plots(results)

## Paraphrase attack (using the same LLM)


## Option 1: Mean detector

In [None]:
# @title Get Mean detector scores for the generated outputs.

# Watermarked responses tend to have higher Mean scores than unwatermarked
# responses. To classify responses you can set a score threshold, but this will
# depend on the distribution of scores for your use-case and your desired false
# positive / false negative rates.

wm_mean_scores = detector_mean.mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_mean_scores = detector_mean.mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print('Mean scores for watermarked responses: ', wm_mean_scores)
print('Mean scores for unwatermarked responses: ', uwm_mean_scores)

# You may find that the Weighted Mean scoring function gives better
# classification performance than the Mean scoring function (in particular,
# higher scores for watermarked responses). See the paper for full details.

wm_weighted_mean_scores = detector_mean.weighted_mean_score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_weighted_mean_scores = detector_mean.weighted_mean_score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print(
    'Weighted Mean scores for watermarked responses: ', wm_weighted_mean_scores
)
print(
    'Weighted Mean scores for unwatermarked responses: ',
    uwm_weighted_mean_scores,
)

Mean scores for watermarked responses:  [0.5353511]
Mean scores for unwatermarked responses:  [0.4979452]
Weighted Mean scores for watermarked responses:  [0.5483145]
Weighted Mean scores for unwatermarked responses:  [0.5022802]
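
One way to pick the score threshold mentioned in the cell above is to fix a target false positive rate and take the matching quantile of unwatermarked scores. The sketch below uses synthetic, normally distributed scores purely for illustration; they are not measurements from this notebook, and `threshold_for_fpr` is a hypothetical helper.

```python
import numpy as np


def threshold_for_fpr(unwatermarked_scores, target_fpr):
    """Score above which roughly `target_fpr` of unwatermarked samples fall."""
    return float(np.quantile(unwatermarked_scores, 1.0 - target_fpr))


rng = np.random.default_rng(0)
# Synthetic score distributions (assumed shapes, for illustration only).
uwm_scores = rng.normal(0.50, 0.01, size=1000)
wm_scores = rng.normal(0.54, 0.01, size=1000)

threshold = threshold_for_fpr(uwm_scores, target_fpr=0.01)
false_positive_rate = float(np.mean(uwm_scores > threshold))
true_positive_rate = float(np.mean(wm_scores > threshold))
print(threshold, false_positive_rate, true_positive_rate)
```

In practice the threshold should be estimated on held-out unwatermarked text that resembles your use case, since score distributions shift with text length and domain.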


## Option 2: Bayesian detector

In [None]:
# @title Generate watermarked samples for training Bayesian detector

gc.collect()
torch.cuda.empty_cache()

model = load_model(MODEL_NAME, expected_device=DEVICE, enable_watermarking=True)
torch.manual_seed(0)

eli5_prompts = datasets.load_dataset("Pavithree/eli5")

wm_outputs = []

for batch_id in tqdm.tqdm(range(NUM_POS_BATCHES)):
  prompts = eli5_prompts['train']['title'][
      batch_id * POS_BATCH_SIZE:(batch_id + 1) * POS_BATCH_SIZE]
  prompts = [_process_raw_prompt(prompt.encode()) for prompt in prompts]
  inputs = tokenizer(
      prompts,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)
  _, inputs_len = inputs['input_ids'].shape

  outputs = model.generate(
      **inputs,
      do_sample=True,
      max_length=inputs_len + OUTPUTS_LEN,
      temperature=TEMPERATURE,
      top_k=TOP_K,
      top_p=TOP_P,
  )

  wm_outputs.append(outputs[:, inputs_len:])

  del outputs, inputs, prompts

del model
gc.collect()
torch.cuda.empty_cache()


  0%|          | 0/313 [00:03<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 986.00 MiB. GPU 0 has a total capacity of 14.75 GiB of which 853.06 MiB is free. Process 2415 has 13.91 GiB memory in use. Of the allocated memory 4.72 GiB is allocated by PyTorch, and 96.89 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
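
The run above stopped on the first batch with a CUDA out-of-memory error. One common mitigation is to retry with a smaller batch size. The sketch below shows that back-off pattern in plain Python. `run_with_backoff` and `fake_generate` are hypothetical names, and the sketch assumes the failure surfaces as a `RuntimeError` whose message contains "out of memory", as the error above is worded. In a real run, the callable would wrap the generation step for one batch.

```python
def run_with_backoff(generate_fn, batch_size, min_batch_size=1):
  """Calls generate_fn(batch_size), halving batch_size after each OOM error.

  Hypothetical helper for illustration; not part of the notebook or library.
  Returns a tuple of (result, batch size that succeeded).
  """
  while True:
    try:
      return generate_fn(batch_size), batch_size
    except RuntimeError as error:
      # Assumption: out-of-memory failures are RuntimeErrors with this wording.
      if 'out of memory' not in str(error).lower():
        raise
      # Give up once the smallest allowed batch size has also failed.
      if batch_size <= min_batch_size:
        raise
      batch_size = max(min_batch_size, batch_size // 2)


def fake_generate(batch_size):
  """Stand-in for a generation call that only fits 2 examples in memory."""
  if batch_size > 2:
    raise RuntimeError('CUDA out of memory.')
  return ['output'] * batch_size


outputs, used_batch_size = run_with_backoff(fake_generate, batch_size=8)
print(used_batch_size, len(outputs))
```

When adapting this to the generation cell, it may also help to drop references to tensors from the failed attempt and call `gc.collect()` and `torch.cuda.empty_cache()` before retrying, as the notebook already does between cells.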

In [None]:
# @title Generate unwatermarked samples for training Bayesian detector

dataset, info = tfds.load('wikipedia/20230601.en', split='train', with_info=True)

dataset = dataset.take(10000)

# Convert the dataset to a DataFrame
df = tfds.as_dataframe(dataset, info)
ds = tf.data.Dataset.from_tensor_slices(dict(df))
tf.random.set_seed(0)
ds = ds.shuffle(buffer_size=10_000)
ds = ds.batch(batch_size=1)

tokenized_uwm_outputs = []
lengths = []
batched = []
# Truncate or right-pad every example to this length so examples can be batched.
padded_length = 2500
for i, batch in tqdm.tqdm(enumerate(ds)):
  responses = [val.decode() for val in batch['text'].numpy()]
  inputs = tokenizer(
      responses,
      return_tensors='pt',
      padding=True,
  ).to(DEVICE)
  line = inputs['input_ids'].cpu().numpy()[0].tolist()
  if len(line) >= padded_length:
    line = line[:padded_length]
  else:
    line = line + [tokenizer.eos_token_id] * (padded_length - len(line))
  batched.append(torch.tensor(line, dtype=torch.long, device=DEVICE)[None, :])
  if len(batched) == NEG_BATCH_SIZE:
    tokenized_uwm_outputs.append(torch.cat(batched, dim=0))
    batched = []
  # Stop once exactly NUM_NEGATIVES examples have been processed.
  if i + 1 >= NUM_NEGATIVES:
    break
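
The loop above forces every tokenized example to exactly `padded_length` tokens, truncating longer sequences and right-padding shorter ones with the end-of-sequence token. A minimal, self-contained sketch of that step follows. `pad_or_truncate` and the token IDs are hypothetical and used only for illustration.

```python
def pad_or_truncate(token_ids, length, pad_id):
  """Returns a new list with exactly `length` entries.

  Longer inputs are cut on the right; shorter inputs are right-padded.
  """
  if len(token_ids) >= length:
    return list(token_ids[:length])
  return list(token_ids) + [pad_id] * (length - len(token_ids))


# Made-up token IDs; 0 stands in for the tokenizer's EOS token ID.
print(pad_or_truncate([5, 6, 7], length=5, pad_id=0))
print(pad_or_truncate([5, 6, 7, 8, 9, 10], length=5, pad_id=0))
```

Fixing the length up front means every example has the same shape, so `torch.cat` can stack them into a batch without further padding.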

In [None]:
# @title Train the Bayesian detector
bayesian_detector, test_loss = (
    detector_bayesian.BayesianDetector.train_best_detector(
        tokenized_wm_outputs=wm_outputs,
        tokenized_uwm_outputs=tokenized_uwm_outputs,
        logits_processor=logits_processor,
        tokenizer=tokenizer,
        torch_device=DEVICE,
        max_padded_length=MAX_PADDED_LENGTH,
        pos_truncation_length=POS_TRUNCATION_LENGTH,
        neg_truncation_length=NEG_TRUNCATION_LENGTH,
        verbose=True,
        learning_rate=3e-3,
        n_epochs=100,
        l2_weights=np.zeros((1,)),
    )
)

In [None]:
# @title Get Bayesian detector scores for the generated outputs.

# Watermarked responses tend to have higher Bayesian scores than unwatermarked
# responses. To classify responses you can set a score threshold, but this will
# depend on the distribution of scores for your use-case and your desired false
# positive / false negative rates. See the paper for full details.

wm_bayesian_scores = bayesian_detector.score(
    wm_g_values.cpu().numpy(), wm_mask.cpu().numpy()
)
uwm_bayesian_scores = bayesian_detector.score(
    uwm_g_values.cpu().numpy(), uwm_mask.cpu().numpy()
)

print('Bayesian scores for watermarked responses: ', wm_bayesian_scores)
print('Bayesian scores for unwatermarked responses: ', uwm_bayesian_scores)
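
A threshold-free way to summarize how well a set of scores separates the two classes is the ROC AUC: the probability that a randomly chosen watermarked response scores higher than a randomly chosen unwatermarked one. The sketch below computes it by direct pairwise comparison in plain Python. `pairwise_auc` and the toy scores are hypothetical and not part of the SynthID Text library.

```python
def pairwise_auc(pos_scores, neg_scores):
  """Fraction of (pos, neg) pairs where pos scores higher; ties count half."""
  if not pos_scores or not neg_scores:
    raise ValueError('Both score lists must be non-empty')
  wins = 0.0
  for pos in pos_scores:
    for neg in neg_scores:
      if pos > neg:
        wins += 1.0
      elif pos == neg:
        wins += 0.5
  return wins / (len(pos_scores) * len(neg_scores))


# Toy scores, made up for illustration only.
toy_wm = [0.9, 0.8, 0.4]
toy_uwm = [0.5, 0.3, 0.2, 0.1]
print('AUC:', pairwise_auc(toy_wm, toy_uwm))
```

A value near 1.0 means the detector ranks watermarked text above unwatermarked text almost always, while 0.5 is no better than chance. The double loop is quadratic, so this form suits small evaluation sets.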