# EigenWatermark

In this notebook, we'll implement and evaluate a simplified version of the watermark of [Tai & Mansour (2019)](https://arxiv.org/abs/1903.08238).

## Google Colab Setup

The cells below handle installation and configuration for Google Colab environments.

In [None]:
# Check if running in Google Colab
try:
    import google.colab
    COLAB = True
    print("Google Colab runtime detected")
except ImportError:
    COLAB = False

# Mount Google Drive to allow for persistent storage (and avoid re-downloading
# code and data)
if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
%%bash

# If running in Colab:
if [[ -n "$COLAB_RELEASE_TAG" ]]; then
  BASE="/content/drive/MyDrive"
  REPO_DIR="$BASE/wm_tutorial"

  if [[ ! -d "$REPO_DIR" ]]; then
    echo "Repo not found — cloning and installing..."
    mkdir -p "$BASE"
    cd "$BASE" && git clone https://github.com/oreillyp/wm_tutorial.git
  else
    echo "Repo already exists — installing without cloning..."
  fi

  cd "$REPO_DIR" && pip install -e .
fi

In [None]:
# Make sure `wm_tutorial` is visible
if COLAB:
    import site
    site.main()

<img src="https://assets.amazon.science/dims4/default/4be5320/2147483647/strip/true/crop/750x250+0+0/resize/1200x400!/format/webp/quality/90/?url=http%3A%2F%2Famazon-topics-brightspot.s3.amazonaws.com%2Fscience%2Ff7%2Fc2%2F1bfe36f74e24a5c79dd83807c5b2%2Faudio-watermark.gif._CB468320145_.gif" width=600 height=600 />

Source: [Amazon](https://www.amazon.science/blog/audio-watermarking-algorithm-is-first-to-solve-second-screen-problem-in-real-time)

## Watermark Embedding

We'll start by implementing the watermark _embedding_ algorithm, which hides watermark data in an audio signal as follows:

1. Set constants `n_s` (the watermark subsequence length) and `n_r` (the number of times to repeat the subsequence)
2. Determine the parameters of a DCT transform that will convert our audio to a time-frequency representation for watermarking, as well as the indices of the frequency bins to watermark (and resulting number of bins `n_bins`)
3. Sample a pseudorandom binary (key) sequence `k` of length `n_s * n_r`
4. Sample `n_s` random orthonormal vectors of dimension `n_bins`, arrange them in a sequence, and repeat `n_r` times
5. Align the vector sequence to the DCT-transformed audio, scale it via a `beta` parameter, and add
6. Take inverse DCT to obtain watermarked audio
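
The six steps above can be sketched end to end on a toy signal. This is only an illustration under stated assumptions: random noise stands in for audio, a hand-built orthonormal DCT-II matrix (`dct_matrix`, a hypothetical helper) stands in for the `dct` / `idct` utilities used later in the notebook, and the start frame is picked arbitrarily.

```python
import math
import torch


def dct_matrix(n: int) -> torch.Tensor:
    """Orthonormal DCT-II matrix of shape (n, n); its transpose is its inverse."""
    k = torch.arange(n, dtype=torch.float64).unsqueeze(1)  # frequency index
    t = torch.arange(n, dtype=torch.float64).unsqueeze(0)  # sample index
    c = math.sqrt(2.0 / n) * torch.cos(math.pi * (t + 0.5) * k / n)
    c[0] /= math.sqrt(2.0)
    return c


g = torch.Generator().manual_seed(0)

# Step 1: constants (toy values, assumed for illustration)
n_s, n_r, beta = 2, 10, 0.75
L = n_s * n_r

# Step 2: non-overlapping frames -> DCT; watermark bins lo..hi
window, lo, hi = 160, 60, 80
n_bins = hi - lo
audio = torch.randn(window * 30, generator=g, dtype=torch.float64)
frames = audio.reshape(-1, window).T           # (window, n_frames)
C = dct_matrix(window)
spec = C @ frames                              # (window, n_frames)

# Step 3: binary key, mapped to signs {-1, +1}
k = torch.randint(0, 2, (L,), generator=g)
sign = (2 * k - 1).to(torch.float64)           # (L,)

# Step 4: n_s orthonormal vectors, repeated n_r times
Q, _ = torch.linalg.qr(torch.randn(n_bins, n_bins, generator=g, dtype=torch.float64))
vecs = Q[:n_s].repeat(n_r, 1)                  # (L, n_bins)

# Step 5: align, scale each frame by beta times its band norm, flip, add
start = 5
band = spec[lo:hi, start:start + L]            # (n_bins, L)
scale = beta * band.norm(dim=0, keepdim=True)  # (1, L)
marked = spec.clone()
marked[lo:hi, start:start + L] += sign * scale * vecs.T

# Step 6: inverse DCT back to a waveform
watermarked = (C.T @ marked).T.reshape(-1)
```

Because each embedding vector has unit norm, the perturbation added to every watermarked frame has exactly `beta` times the norm of that frame's band content, which is what makes `beta` a relative strength knob.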


### Set Embedding Parameters

In [None]:
import torch
import random
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
from audiotools import AudioSignal, STFTParams
from audiotools.data.preprocess import create_csv
from audiotools.data.transforms import BaseTransform, Compose, BackgroundNoise, RoomImpulseResponse

from wm_tutorial.constants import MANIFESTS_DIR, DATA_DIR, ASSETS_DIR
from wm_tutorial.util import dct, idct

NOISE_DIR = DATA_DIR / "noise-database" / "room"
RIR_DIR = DATA_DIR / "rir-database" / "real"

sample_rate = 16_000
window_length_ms = 10

stft_params = STFTParams(
    window_length=int(sample_rate * window_length_ms / 1000),
    hop_length=int(sample_rate * window_length_ms / 1000),     # No overlap
    window_type="boxcar",
)

n_s = 2
n_r = 100
beta = 0.75

f_min_hz = 3_000
f_max_hz = 4_000

# Compute number of embedding bins (for DCT, equal to window length)
nyquist = sample_rate / 2
total_bins = stft_params.window_length

f_min_bin = max(0, int(f_min_hz * total_bins / nyquist))
f_max_bin = min(total_bins, int(f_max_hz * total_bins / nyquist))

n_bins = f_max_bin - f_min_bin

# Because we construct orthogonal vectors via a square matrix, the number of 
# rows/columns we can take for our subsequence is at most the vector dimension
# n_bins
assert n_s <= n_bins

print(
    f"Embedding parameters: "
    f"subsequence length n_s={n_s}, subsequence repeats n_r={n_r}, "
    f"scale beta={beta :0.2f}, embedding bins n_bins={n_bins}"
)

### Load Audio

In [None]:
# Load an example audio file
signal = AudioSignal(ASSETS_DIR / "audio" / "bryan_0.wav").resample(sample_rate)
signal.stft_params = stft_params
signal.widget()

### Obtain Orthogonal Vectors & Key Sequence

In [None]:
def orthogonal_matrix(d: int, seed: int):
    """Create a d×d random orthonormal matrix."""
    assert d > 0

    g = torch.Generator().manual_seed(seed)
    A = torch.randn(d, d, generator=g, dtype=torch.float64)
    Q, R = torch.linalg.qr(A, mode='reduced')
    s = torch.sign(torch.diag(R))
    s[s == 0] = 1
    Q = Q * s
    return Q


def random_binary_sequence(l: int, seed: int):
    """Create a length-l random binary sequence with values in {0, 1}."""
    assert l > 0
    
    g = torch.Generator().manual_seed(seed)
    return torch.randint(0, 2, (l,), generator=g, dtype=torch.int64)

In [None]:
# Sample key sequence
k = random_binary_sequence(n_s * n_r, seed=0)  # (n_s * n_r,)

# Create subsequence of n_s orthogonal vectors, each of dimension n_bins
m = orthogonal_matrix(n_bins, seed=0)  # (n_bins, n_bins)
subseq = m[:n_s]                       # (n_s, n_bins)

# Repeat n_r times
subseq = subseq.repeat([n_r, 1])       # (n_s * n_r, n_bins)

# Plot
plt.imshow(subseq.T[:, :n_s * 4], aspect="auto", origin="lower", interpolation="none")
plt.vlines(torch.arange(0, n_s * min(n_r, 4), n_s) - 0.5, ymin=-0.5, ymax=n_bins - 0.5, colors="white", linewidth=2)
plt.title("Repeated embedding vector sequence")
plt.xlabel("Frames")
plt.ylabel("Bins")
plt.show()

### Identify Embedding Region

In [None]:
# DCT spectrogram
spec = dct(signal)  # (n_batch, n_channels, n_freq, n_frames)
_spec = spec.mean(1)[0]  # (n_freq, n_frames)

# Embedding vectors
start_frame = 20
vecs = torch.zeros_like(_spec)
vecs[f_min_bin:f_max_bin, start_frame:start_frame + n_s * n_r] = subseq.T
fig, ax = plt.subplots()
ax.imshow((_spec**2).log1p() + vecs, origin="lower", aspect="auto", interpolation="none")

# Embedding frequency range
ax.axhline(f_min_bin, color='red', linewidth=2)
ax.axhline(f_max_bin, color='red', linewidth=2)

# Subsequence boundaries
subseq_boundaries = torch.arange(start_frame, start_frame + n_s * n_r + 1, n_s)
ax.vlines(subseq_boundaries, ymin=f_min_bin, ymax=f_max_bin, colors='white', linewidth=0.75)

# Plot
ax.set_title("Embedding region")
ax.set_xlabel("Frames")
ax.set_ylabel("Bins")
plt.show()

### Scale, Add, & Invert

In [None]:
# Scale relative to original signal magnitude in embedding region
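orig_mag = _spec[f_min_bin:f_max_bin, start_frame:start_frame + n_s * n_r]  # (n_bins, n_s * n_r)
scale = beta * orig_mag.norm(dim=0, p=2, keepdim=True)                      # (1, n_s * n_r)
scaled = scale * subseq.T                                                   # (n_bins, n_s * n_r)

# Apply flips, mapping the binary key {0, 1} to signs {-1, +1} (multiplying by
# the raw key would zero out frames instead of flipping them)
sign = 2 * k.unsqueeze(0) - 1                                               # (1, n_s * n_r)
flipped = sign * scaled                                                     # (n_bins, n_s * n_r)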

# Add to embed
spec[:, :, f_min_bin:f_max_bin, start_frame:start_frame + n_s * n_r] += flipped

# With watermark now embedded in DCT spectrogram we can invert to obtain 
# waveform audio
watermarked = idct(signal.clone(), spec)

watermarked.widget()

If we want to hear the watermark, we can simply take the difference between the watermarked and unwatermarked signals!

In [None]:
(watermarked - signal).widget()

In fact, that's how we settled on an embedding frequency range of 3kHz–4kHz: using the "before/after" audio examples provided by the authors for this [blog post](https://www.amazon.science/blog/audio-watermarking-algorithm-is-first-to-solve-second-screen-problem-in-real-time)!

In [None]:
before = AudioSignal(ASSETS_DIR / "audio" / "eigen_unwatermarked.wav")
after = AudioSignal(ASSETS_DIR / "audio" / "eigen_watermarked.wav")

before.widget()
after.widget()
(after - before).widget()

## Watermark Detection

Now that we've embedded our watermark in an audio signal, how can we detect it? We'll adapt the "self-correlation" test proposed by Tai & Mansour, leveraging our knowledge of the secret key vector `k` to undo random flips. Detection operates as follows:

1. Take the DCT spectrogram of the audio recording and isolate the subband spanning our embedding region of `f_min_hz` to `f_max_hz`
2. For each possible watermark starting frame, take a spectrogram segment of the watermark length `n_s * n_r` and apply the flips specified by our key
3. Divide the segment into `n_r` subsequences of length `n_s` and perform "self-correlation" scoring for each possible pair of subsequences
4. Sum all self-correlation scores and move on to the next segment by advancing one frame; return all scores when finished
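
To see why undoing the flips matters, the scoring of a single segment can be sketched on synthetic data. The helper name `self_correlation_score`, the noise level, and the "wrong key" construction below are assumptions made for illustration; the band contains only the flipped watermark vectors plus a little noise, with no host audio.

```python
import torch


def self_correlation_score(band, sign, n_s, n_r):
    """Score one (n_bins, n_s * n_r) segment: undo flips, fold into repeats,
    and average cosine similarity over all pairs of repeats."""
    flipped = band * sign.unsqueeze(0)
    folded = flipped.reshape(band.shape[0], n_r, n_s).permute(2, 1, 0)  # (n_s, n_r, n_bins)
    unit = folded / folded.norm(dim=2, keepdim=True).clamp_min(1e-8)
    sim = torch.bmm(unit, unit.transpose(1, 2)).sum(dim=0)              # (n_r, n_r)
    n_pairs = n_r * (n_r - 1) // 2
    return torch.triu(sim, diagonal=1).sum().item() / n_pairs


g = torch.Generator().manual_seed(0)
n_bins, n_s, n_r = 20, 2, 50
L = n_s * n_r

# Repeated orthonormal vectors with key-controlled sign flips, plus mild noise
Q, _ = torch.linalg.qr(torch.randn(n_bins, n_bins, generator=g, dtype=torch.float64))
vecs = Q[:n_s].repeat(n_r, 1).T                                         # (n_bins, L)
sign = (2 * torch.randint(0, 2, (L,), generator=g) - 1).to(torch.float64)
band = sign * vecs + 0.05 * torch.randn(n_bins, L, generator=g, dtype=torch.float64)

# A wrong key: disagrees with the true key on every other repeat
alt = torch.tensor([1.0, -1.0], dtype=torch.float64).repeat_interleave(n_s).repeat(n_r // 2)
wrong_sign = sign * alt

score_right = self_correlation_score(band, sign, n_s, n_r)
score_wrong = self_correlation_score(band, wrong_sign, n_s, n_r)
print(f"correct key: {score_right:0.2f}, wrong key: {score_wrong:0.2f}")
```

With the correct key, every repeat of a given vector lines up with the same sign, so the pairwise similarities add up (the score approaches `n_s`). With a wrong key, agreeing and disagreeing pairs largely cancel and the score stays near zero.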

In [None]:
def detect(
    s: AudioSignal, 
    sample_rate: int,
    stft_params: STFTParams,
    f_min_hz: float,
    f_max_hz: float,
    n_s: int,
    n_r: int,
    k: torch.Tensor,
):
    # Isolate embedding band from spectrogram
    signal = s.clone().resample(sample_rate)
    signal.stft_params = stft_params
    spec = dct(signal)                        # (n_batch, n_channels, n_freq, n_frames)
    _spec = spec.mean(1)[0]                   # (n_freq, n_frames)
    
    nyquist = sample_rate / 2
    total_bins = stft_params.window_length

    f_min_bin = max(0, int(f_min_hz * total_bins / nyquist))
    f_max_bin = min(total_bins, int(f_max_hz * total_bins / nyquist))

    band = _spec[f_min_bin:f_max_bin]  # (n_bins, n_frames)
    
    # For every length-(n_s * n_r) segment of embedding band, apply flips 
    # specified by key, correlate each pair of length-n_s flipped subsequences,
    # and store sum
    sums = []
    for i in range(band.shape[-1] - (n_s * n_r) + 1):

        # Select length-(n_s * n_r) segment
        selected = band[:, i:i + n_s * n_r].clone()
        assert selected.shape[-1] == n_s * n_r

        # Apply flips, turning binary sequence into sign sequence
        selected = selected * (2 * k.float().unsqueeze(0) - 1)

        # "Fold" to separate subsequences
        folded = selected.reshape(selected.shape[0], n_r, n_s)  # (n_bins, n_r, n_s)
        folded = folded.permute(2, 1, 0)  # (n_s, n_r, n_bins)

        # Self-correlation
        normalized = folded / folded.norm(dim=2, keepdim=True).clamp_min(1e-8)
        sim = torch.bmm(normalized, normalized.transpose(1, 2)).sum(dim=0)

        # Zero redundant self-correlations
        sim = torch.triu(sim, diagonal=1)

        sums += [sim.sum().item() / (n_r * (n_r - 1) // 2)]
    
    return sums

Let's run our detection algorithm on watermarked and unwatermarked audio and see if we observe a difference in scores!

In [None]:
# Run detection on unwatermarked audio
scores_unwatermarked = detect(
    signal,
    sample_rate,
    stft_params,
    f_min_hz,
    f_max_hz,
    n_s,
    n_r,
    k
)

# Run detection on watermarked audio 
scores_watermarked = detect(
    watermarked,
    sample_rate,
    stft_params,
    f_min_hz,
    f_max_hz,
    n_s,
    n_r,
    k
)

# Plot detection scores
plt.plot(scores_unwatermarked, label="Unwatermarked", linewidth=1, color="green")
plt.hlines(max(scores_unwatermarked), xmin=0, xmax=len(scores_unwatermarked), linestyle="--", color="green", alpha=0.5)
plt.plot(scores_watermarked, label="Watermarked", linewidth=1, color="red")
plt.hlines(max(scores_watermarked), xmin=0, xmax=len(scores_watermarked), linestyle="--", color="red", alpha=0.5)
plt.xlabel("Position")
plt.ylabel("Self-correlation score")
plt.legend()
plt.show()

Indeed, there appears to be a large enough "score gap" to let us discriminate between watermarked and unwatermarked audio!

Similar to detection, we can wrap our embedding algorithm in a single function for convenience.

In [None]:
def embed(
    s: AudioSignal, 
    sample_rate: int,
    stft_params: STFTParams,
    f_min_hz: float,
    f_max_hz: float,
    n_s: int,
    n_r: int,
    beta: float,
    seed: int,
):
    
    # Isolate embedding band from spectrogram
    signal = s.clone().resample(sample_rate)
    signal.stft_params = stft_params
    spec = dct(signal)                        # (n_batch, n_channels, n_freq, n_frames)

    nyquist = sample_rate / 2
    total_bins = stft_params.window_length

    f_min_bin = max(0, int(f_min_hz * total_bins / nyquist))
    f_max_bin = min(total_bins, int(f_max_hz * total_bins / nyquist))
    n_bins = f_max_bin - f_min_bin
    L = n_s * n_r

    # Sample key sequence
    k = random_binary_sequence(L, seed=seed)  # (n_s * n_r,)

    # Create subsequence of n_s orthogonal vectors (each dim n_bins), then repeat n_r times
    m = orthogonal_matrix(n_bins, seed=0)     # (n_bins, n_bins)
    subseq = m[:n_s]                          # (n_s, n_bins)
    subseq = subseq.repeat([n_r, 1])          # (n_s * n_r, n_bins)

    # Choose start frame (maximum-energy subsequence)
    band_all = spec.mean(1)[0][f_min_bin:f_max_bin, :]    # (n_bins, T)
    T = band_all.shape[-1]
    if T < L:
        raise ValueError(f"Not enough frames ({T}) for watermark length n_s * n_r = {L}.")

    # Per-frame strength
    strength = (band_all ** 2).sum(dim=0).to(dtype=torch.float32)  # (T,)

    # Rolling sum over window length L
    kernel = torch.ones(1, 1, L, device=strength.device, dtype=strength.dtype)
    window_sums = torch.nn.functional.conv1d(
        strength.view(1, 1, -1), 
        kernel
    )  # (1, 1, T - L + 1)
    st = int(torch.argmax(window_sums).item())

    # Scale per frame and embed
    band = band_all[:, st:st + L]                                 # (n_bins, L)
    per_frame_scale = beta * band.norm(dim=0, p=2, keepdim=True)  # (1, L)
    scaled = subseq.T.to(band.dtype) * per_frame_scale            # (n_bins, L)

    # Apply flips: turn {0,1} into {-1,+1}
    k_tensor = torch.as_tensor(k, device=band.device, dtype=band.dtype)  # (L,)
    sign = (2.0 * k_tensor - 1.0).unsqueeze(0)                           # (1, L)
    flipped = scaled * sign                                              # (n_bins, L)

    # Add to embed (broadcast over batch/channels)
    spec[:, :, f_min_bin:f_max_bin, st:st + L] += flipped.unsqueeze(0).unsqueeze(0)

    # Invert
    signal = idct(signal, spec)

    # Restore original sample rate and STFT params
    signal = signal.resample(s.sample_rate)
    signal.stft_params = s.stft_params

    # Ensure length does not change
    signal.audio_data = signal.audio_data[..., :s.shape[-1]]
    signal.audio_data = torch.nn.functional.pad(signal.audio_data, (0, max(0, s.shape[-1] - signal.shape[-1])))

    return signal, k

## Evaluating Watermark Performance

Now that we've got a functioning audio watermark, we'll want to evaluate its performance. In this case, we care about (1) our ability to discriminate between watermarked and unwatermarked audio under adverse conditions -- i.e., __robustness__, and (2) the degree to which our watermark preserves audio quality -- i.e., __perceptual transparency__.

We'll start by implementing a simple measure of discrimination performance: the achievable true-positive detection rate given a fixed allowable false-positive detection rate, __TPR@FPR__. This reflects the fact that false positives are often costly for real-world watermarking schemes operating at scale (e.g. flagging potential deepfake videos on a social media platform for human review). Ideally, we want to maintain very high true-positive rates even at small (<1%) false-positive rates.
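
To make the definition concrete, here is a small threshold-based sketch on made-up scores. The function name `tpr_at_fpr_threshold` and the toy numbers are hypothetical: it counts how many negatives may be flagged under the target FPR, sets the threshold just above the next negative score, and reports the fraction of positives that clear it.

```python
import math
import torch


def tpr_at_fpr_threshold(scores_pos, scores_neg, target_fpr: float) -> float:
    pos = torch.as_tensor(scores_pos, dtype=torch.float64)
    neg = torch.as_tensor(scores_neg, dtype=torch.float64)

    # Number of negatives we may flag without exceeding the target FPR
    allowed = math.floor(target_fpr * neg.numel() + 1e-9)
    if allowed >= neg.numel():
        return 1.0

    # Flag scores strictly above the (allowed + 1)-th largest negative score
    threshold = torch.sort(neg, descending=True).values[allowed]
    return (pos > threshold).double().mean().item()


negatives = [i / 10 for i in range(10)]   # 0.0, 0.1, ..., 0.9
positives = [0.5, 0.85, 0.95, 1.2]

tpr_10 = tpr_at_fpr_threshold(positives, negatives, 0.10)  # 0.75
tpr_0 = tpr_at_fpr_threshold(positives, negatives, 0.0)    # 0.5
print(tpr_10, tpr_0)
```

Allowing one false positive out of ten lowers the threshold from 0.9 to 0.8, which lets one more watermarked example through. With only ten negatives, FPR targets below 10% all collapse to "no false positives", which is why small FPR targets need many unwatermarked examples.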

In [None]:
def tpr_at_fpr(scores_true, scores_false, target_fpr: float):
    """
    Compute achievable True Positive Rate (TPR) at a given False Positive Rate (FPR).
    """
    # Convert to tensors
    s_true  = torch.as_tensor(scores_true, dtype=torch.float32)
    s_false = torch.as_tensor(scores_false, dtype=torch.float32)

    # Concatenate scores and labels
    scores = torch.cat([s_true, s_false])
    labels = torch.cat([
        torch.ones_like(s_true, dtype=torch.int32),
        torch.zeros_like(s_false, dtype=torch.int32)
    ])

    # Sort scores descending
    sorted_scores, idx = torch.sort(scores, descending=True)
    sorted_labels = labels[idx]

    # Cumulative counts
    tp_cum = torch.cumsum(sorted_labels, dim=0)
    fp_cum = torch.cumsum(1 - sorted_labels, dim=0)

    # Totals
    tp_total = s_true.numel()
    fp_total = s_false.numel()

    # Compute TPR and FPR
    tpr = tp_cum.float() / tp_total
    fpr = fp_cum.float() / fp_total

    # Mask for achievable FPR
    mask = fpr <= target_fpr
    return tpr[mask].max().item() if mask.any() else 0.0

Now, we'll embed our watermark repeatedly in our tiny dataset of speech recordings, placing it in the highest-energy stretch of the embedding band and sampling a new key each time. We will then run the detection algorithm on both watermarked and unwatermarked recordings, taking the maximum detection score per recording. Finally, we will check our achievable true-positive detection rates at fixed false-positive rates.

In [None]:
n_examples = 500

recordings = [ASSETS_DIR / "audio" / f"bryan_{i}.wav" for i in range(4)]

scores_watermarked = []
scores_unwatermarked = []

for i in tqdm(range(n_examples)):

    # Select random recording
    signal = AudioSignal(random.choice(recordings)).resample(sample_rate)

    # Embed watermark
    watermarked, k = embed(
        signal.clone(),
        sample_rate,
        stft_params,
        f_min_hz,
        f_max_hz,
        n_s,
        n_r,
        beta,
        seed=i
    )

    # Detect watermark
    scores_watermarked += [
        max(
            detect(
                watermarked,
                sample_rate,
                stft_params,
                f_min_hz,
                f_max_hz,
                n_s,
                n_r,
                k
            )
        )
    ]
    scores_unwatermarked += [
        max(
            detect(
                signal,
                sample_rate,
                stft_params,
                f_min_hz,
                f_max_hz,
                n_s,
                n_r,
                k
            )
        )
    ]


print(
    f"TPR @ 10% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.10) :0.2f}\n"
    f"TPR @ 1% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.01) :0.2f}\n"
    f"TPR @ 0.1% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.001) :0.2f}\n"
)

Not too shabby! But how can we measure the perceptual transparency of our watermark?

The simplest metrics quantify the distance between a watermarked signal and a clean reference signal of the same length (i.e. the unwatermarked audio). These signal-level metrics can give us a good idea as to the "magnitude" of our watermark, but do not necessarily align well with human perception. Our first metric, __SNR__, measures the energy of the reference waveform relative to the energy of the difference between the input and reference waveforms, in decibels. Higher values mean the two signals are closer.

In [None]:
def snr(x: AudioSignal, ref: AudioSignal, eps: float = 1e-10):
    """Signal-to-Noise Ratio: 10 * log10( ||ref||^2 / ||x - ref||^2 )"""

    assert x.sample_rate == ref.sample_rate
    assert x.shape == ref.shape

    x = x.audio_data.to(torch.float64)
    ref = ref.audio_data.to(torch.float64)
    
    noise = x - ref
    num = (ref ** 2).sum(dim=-1)               # (n_batch, n_channels)
    den = (noise ** 2).sum(dim=-1) + eps
    
    snr_val = 10.0 * torch.log10(torch.clamp(num, min=eps) / den)
    snr_val = snr_val.to(torch.float32)
    return snr_val.mean(dim=-1)

In [None]:
# Sample signal
signal = AudioSignal(random.choice(recordings)).resample(sample_rate)

# Embed watermark
watermarked, k = embed(
    signal.clone(),
    sample_rate,
    stft_params,
    f_min_hz,
    f_max_hz,
    n_s,
    n_r,
    beta,
    seed=0  # Fixed seed; avoids relying on the loop variable left over from the previous cell
)

print(f"SNR: {snr(watermarked, signal).item() :0.2f}")

While SNR's simplicity makes it interpretable and easy to implement, it has a few clear issues. One issue is its sensitivity to scaling: if two recordings differ only in their "volume," even slightly, we might obtain a small SNR value despite their perceptual similarity. For example, SNR "thinks" that these two recordings...

In [None]:
scaled = signal * 0.75

print(f"SNR: {snr(scaled, signal).item() :0.2f}")

scaled.widget()
signal.widget()

... are about as different as these two recordings:

In [None]:
noise_level = snr(scaled, signal).item()
noise_audio = signal.clone()
noise_audio.audio_data = torch.randn_like(noise_audio.audio_data)
noisy = signal.clone().mix(noise_audio, noise_level)

print(f"SNR: {snr(noisy, signal).item() :0.2f}")

signal.widget()
noisy.widget()

This particular failure mode was addressed by [Le Roux et al.](https://arxiv.org/pdf/1811.02508), who proposed the __SI-SDR__ metric to account for the effect of scale differences on the related SDR (signal-to-distortion ratio) metric.
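
A quick way to see the difference is to compare both metrics on an exact scaled copy of a signal. The sketch below works on plain tensors rather than `AudioSignal` objects, and the helper names `snr_db` and `si_sdr_db` are hypothetical. For `x = a * ref`, the SNR reduces to `-20 * log10(|1 - a|)`, while the projection step in SI-SDR absorbs the scale so that almost no error remains.

```python
import math
import torch


def snr_db(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """10 * log10(||ref||^2 / ||x - ref||^2)"""
    return 10 * torch.log10((ref ** 2).sum() / ((x - ref) ** 2).sum())


def si_sdr_db(x: torch.Tensor, ref: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Project x onto ref first, then compare the projection to the residual."""
    alpha = (x * ref).sum() / ((ref ** 2).sum() + eps)
    target = alpha * ref
    err = x - target
    return 10 * torch.log10((target ** 2).sum() / ((err ** 2).sum() + eps))


g = torch.Generator().manual_seed(0)
ref = torch.randn(16_000, generator=g, dtype=torch.float64)
a = 0.75
x = a * ref

snr_val = snr_db(x, ref).item()
si_sdr_val = si_sdr_db(x, ref).item()
closed_form = -20 * math.log10(abs(1 - a))   # about 12.04 dB

print(f"SNR: {snr_val:0.2f} dB (closed form {closed_form:0.2f} dB)")
print(f"SI-SDR: {si_sdr_val:0.2f} dB")
```

The SNR of roughly 12 dB depends only on the scale factor, not on the signal content. The SI-SDR value is very large and is limited only by the `eps` term added for numerical stability.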

In [None]:
def si_sdr(x: AudioSignal, ref: AudioSignal, zero_mean: bool = True, eps: float = 1e-10):
    """
    Scale-Invariant SDR (Le Roux et al., 2019):
        Let x̄, s̄ be (optionally) zero-mean versions of estimate and reference.
        alpha = <x̄, s̄> / ||s̄||^2
        s_target = alpha * s̄
        e = x̄ - s_target
        SI-SDR = 10 * log10( ||s_target||^2 / ||e||^2 )
    """
    assert x.sample_rate == ref.sample_rate
    assert x.shape == ref.shape

    x = x.audio_data.to(torch.float64)
    ref = ref.audio_data.to(torch.float64)

    if zero_mean:
        # Subtract mean over time per batch and channel index
        x = x - x.mean(dim=-1, keepdim=True)
        ref = ref - ref.mean(dim=-1, keepdim=True)

    # Project input onto reference
    ref_energy = (ref ** 2).sum(dim=-1, keepdim=True)  # (n_batch, n_channels, 1)
    
    # Avoid division by zero if reference is silent
    alpha = (x * ref).sum(dim=-1, keepdim=True) / (ref_energy + eps)
    s_target = alpha * ref
    e = x - s_target

    num = (s_target ** 2).sum(dim=-1)           # (n_batch, n_channels)
    den = (e ** 2).sum(dim=-1) + eps
    si_sdr_val = 10.0 * torch.log10(torch.clamp(num, min=eps) / den)
    si_sdr_val = si_sdr_val.to(torch.float32)
    return si_sdr_val.mean(dim=-1)

Here, we can see that SI-SDR isn't tripped up by scale differences -- we get a much larger value (in decibels) than SNR for the same scaled recording pair, indicating that the "difference" is minimal.

In [None]:
scaled = signal * 0.75

print(f"SI-SDR: {si_sdr(scaled, signal).item() :0.2f}")

There are plenty of more sophisticated signal-level metrics that attempt to mimic human perception of differences between audio signals, such as [__PESQ__](https://en.wikipedia.org/wiki/Perceptual_Evaluation_of_Speech_Quality) and [__STOI__](https://ieeexplore.ieee.org/document/5713237). However, for watermarking purposes, it's crucial to understand that _existing signal-level metrics are misaligned with human perception under many circumstances, and should not be treated as a "gold standard" by which to measure the perceptual transparency of watermarks!_

For example, SNR, SI-SDR, PESQ, STOI, and other metrics would all consider the following pair of signals to be very different because they are _slightly misaligned in time_. Do they sound different to you?

In [None]:
shifted = signal.clone()
shifted.audio_data = torch.roll(shifted.audio_data, shifts=100, dims=-1)

print(f"SNR: {snr(shifted, signal).item() :0.2f}")
print(f"SI-SDR: {si_sdr(shifted, signal).item() :0.2f}")

signal.widget()
shifted.widget()

Where does this leave us? Signal-level metrics can still be helpful for getting a rough idea of how an audio watermark manifests -- for instance, a high SNR / SI-SDR value indicates that a watermark is low in "magnitude" relative to the host signal. There are also a number of [recent](https://arxiv.org/abs/2505.20741) [metrics](https://arxiv.org/abs/2304.01448) that do not require a precisely time-aligned reference signal, or even a reference signal at all. For researchers developing novel watermarking methods, using a diverse array of metrics (and clearly understanding exactly what each measures!) can help avoid overfitting to a single target.

Ultimately though, the best test of perceptual transparency is a carefully conducted human listening evaluation in a standard format like [ABX](https://en.wikipedia.org/wiki/ABX_test) or [MUSHRA](https://en.wikipedia.org/wiki/MUSHRA). If we want to design a human-imperceptible watermark, there is no better judge than a human!

### Watermark Robustness

We've touched on perceptual transparency and measured our watermark's detection performance in a "clean" setting. But how will this watermark fare in the real world, where audio is often in less-than-pristine condition by the time it reaches a watermark detector? To put EigenWatermark to the test, __we'll simulate two common distortions__ -- additive noise and reverberation -- and see whether our detection performance falls off when watermarked audio is modified.

Luckily for us, the `audiotools` library comes with built-in implementations of many standard audio transformations. For noise and reverberation in particular, we'll use datasets of real-world recorded background noise and room impulse responses, respectively, to model these distortions.

In [None]:
# AudioTools noise & reverb transforms require that noise and impulse response
# recordings, respectively, be provided via filepaths in a .csv. We can create
# these .csv files from our datasets using a built-in utility function
create_csv(
    audio_files=list((Path(NOISE_DIR)).rglob("*.wav")),
    output_csv=MANIFESTS_DIR / "noise_room.csv", 
)
create_csv(
    audio_files=list((Path(RIR_DIR)).rglob("*.wav")),
    output_csv=MANIFESTS_DIR / "rir_real.csv",
)

Each transformation is initialized with both fixed and randomized parameters. Randomized parameters allow for realizing a large number of "versions" of the transformation -- e.g., with different background noise levels or room reverberation characteristics. We specify allowable values for these parameters via a tuple indicating the sampling strategy (e.g. `"uniform"`, `"const"`, `"choice"`) and the allowable values/ranges.
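
As a rough mental model, a specification tuple can be read as "strategy name, then arguments". The helper below, `sample_param`, is a hypothetical stand-in written for illustration and is not the `audiotools` implementation. It only shows how the three strategies named above could be resolved with a seeded NumPy random state.

```python
from numpy.random import RandomState


def sample_param(spec: tuple, state: RandomState):
    """Resolve a (strategy, *args) tuple into a concrete value."""
    kind, *args = spec
    if kind == "const":
        # Fixed value, no randomness involved
        return args[0]
    if kind == "uniform":
        low, high = args
        return state.uniform(low, high)
    if kind == "choice":
        # Pick one entry from a sequence of allowed values
        return state.choice(args[0])
    raise ValueError(f"Unknown sampling strategy: {kind}")


state = RandomState(0)
eq_amount = sample_param(("const", 1.0), state)
snr_db = sample_param(("uniform", 20.0, 30.0), state)
factor = sample_param(("choice", (0.99, 1.01)), state)
print(eq_amount, snr_db, factor)

# The same seed reproduces the same draw
assert sample_param(("uniform", 20.0, 30.0), RandomState(1)) == \
       sample_param(("uniform", 20.0, 30.0), RandomState(1))
```

Seeding the random state is what makes a sampled transformation reproducible, so the same distortion can be applied to both a watermarked signal and its unwatermarked counterpart.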

In [None]:
noise = BackgroundNoise(
    snr=("uniform", 20.0, 30.0),  # Sample noise level uniformly in [20, 30]dB
    sources=[MANIFESTS_DIR / "noise_room.csv"],
    eq_amount=("const", 1.0),     # Sample EQ level as a fixed value of 1.0
    n_bands=3,
    prob=1.0,
    loudness_cutoff=None,
)

reverb = RoomImpulseResponse(
    drr=("uniform", 10.0, 30.0),   # Sample reverb direct-reverberant ratio uniformly in [10, 30]dB
    sources=[MANIFESTS_DIR / "rir_real.csv"],
    eq_amount=("const", 1.0),     # Sample EQ level as a fixed value of 1.0
    n_bands=6,
    prob=1.0,
    use_original_phase=False,
    offset=0.0,
    duration=1.0,
)

# We can combine multiple transforms in sequence!
noise_and_reverb = Compose(noise, reverb)

We can then apply a transformation to audio in two steps:

1. Sample randomized parameters via `.instantiate()` (passing the signal to be transformed and a random seed) to obtain a dictionary that deterministically specifies the sampled transformation
2. Pass the signal and dictionary of sampled parameters to `.transform()` to apply the transformation

In [None]:
signal.widget()

noise_kwargs = noise.instantiate(1, signal)
print("Sampled noise parameters:", noise_kwargs)
out_noise = noise.transform(signal.clone(), **noise_kwargs)
out_noise.widget()

reverb_kwargs = reverb.instantiate(1, signal)
print("Sampled reverb parameters:", reverb_kwargs)
out_reverb = reverb.transform(signal.clone(), **reverb_kwargs)
out_reverb.widget()

both_kwargs = noise_and_reverb.instantiate(1, signal)
print("Sampled noise+reverb parameters:", both_kwargs)
out_both = noise_and_reverb.transform(signal.clone(), **both_kwargs)
out_both.widget()

Now, let's re-run our evaluation, but apply random noise and reverb to each recording before performing detection.

In [None]:
n_examples = 500

recordings = [ASSETS_DIR / "audio" / f"bryan_{i}.wav" for i in range(4)]

scores_watermarked = []
scores_unwatermarked = []

for i in tqdm(range(n_examples)):

    # Select random recording
    signal = AudioSignal(random.choice(recordings)).resample(sample_rate)
    
    # Embed watermark
    watermarked, k = embed(
        signal.clone(),
        sample_rate,
        stft_params,
        f_min_hz,
        f_max_hz,
        n_s,
        n_r,
        beta=beta,
        seed=i
    )

    # Apply random transformation to both unwatermarked and watermarked signal
    tfm_kwargs = noise_and_reverb.instantiate(i, signal)  # Seed with i so each example gets a different distortion
    tfm_signal = noise_and_reverb.transform(signal.clone(), **tfm_kwargs)
    tfm_watermarked = noise_and_reverb.transform(watermarked.clone(), **tfm_kwargs)
    
    # Detect watermark
    scores_watermarked += [
        max(
            detect(
                tfm_watermarked,
                sample_rate,
                stft_params,
                f_min_hz,
                f_max_hz,
                n_s,
                n_r,
                k
            )
        )
    ]
    scores_unwatermarked += [
        max(
            detect(
                tfm_signal,
                sample_rate,
                stft_params,
                f_min_hz,
                f_max_hz,
                n_s,
                n_r,
                k
            )
        )
    ]


print(
    f"TPR @ 10% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.10) :0.2f}\n"
    f"TPR @ 1% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.01) :0.2f}\n"
    f"TPR @ 0.1% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.001) :0.2f}\n"
)

And look at that -- our watermark maintains strong performance even under simulated acoustic distortions! This isn't just luck. Tai & Mansour explicitly designed their watermark for speech in noisy, reverberant environments. Key choices include:
* Embedding the watermark in a narrow frequency band (3-4kHz) that is often pronounced in speech recordings
* Leveraging _self-correlation_ rather than _cross-correlation_ to embed and detect the watermark, as stationary additive noise and time-invariant filters (which reverberation can be modeled as) do not interfere with spectral self-similarity between different segments of the signal.
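
The second point can be illustrated with a deliberately idealized sketch. It assumes the filter acts as a fixed gain per frequency bin (only an approximation for a block transform like the DCT) and that the band holds nothing but the watermark vector. Under those assumptions, two repeats of the same vector remain identical after filtering, whereas the match against the original, unfiltered vector degrades.

```python
import torch

g = torch.Generator().manual_seed(0)
n_bins = 20

# A unit-norm embedding vector (stand-in for one row of the orthonormal matrix)
v = torch.randn(n_bins, generator=g, dtype=torch.float64)
v = v / v.norm()

# Idealized time-invariant filter: one fixed positive gain per bin
gains = 0.2 + 1.8 * torch.rand(n_bins, generator=g, dtype=torch.float64)

# Two repeats of the vector pass through the same filter
first_repeat = gains * v
second_repeat = gains * v

cos = torch.nn.functional.cosine_similarity
self_corr = cos(first_repeat, second_repeat, dim=0).item()   # repeat vs. repeat
cross_corr = cos(first_repeat, v, dim=0).item()              # repeat vs. known template

print(f"self-correlation:  {self_corr:0.3f}")
print(f"cross-correlation: {cross_corr:0.3f}")
```

The self-correlation stays at 1 no matter which gains are drawn, because the detector never compares against a stored template. The cross-correlation falls below 1 whenever the gains are not all equal.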

EigenWatermark was originally proposed to prevent Amazon Alexa from "waking" when the name appeared in television commercials. If the watermark were embedded in the audio track of a commercial, with the watermark key distributed to Alexa devices, they could be configured to prevent "waking" when the watermark was detected. However, there's nothing to stop this watermarking approach from [also being applied to synthetic speech identification](https://interactiveaudiolab.github.io/assets/papers/oreilly_jin_su_pardo_watermark.pdf).

That said, EigenWatermark isn't universally robust. Let's try a couple of audio transformations the original authors didn't consider because they are irrelevant in the wakeword-detection scenario: spectral gating and speed change.

In [None]:
from audiotools.core.util import sample_from_dist
from audiotools.data.transforms import MaskLowMagnitudes
from numpy.random import RandomState

class Speed(BaseTransform):

    def __init__(
        self, 
        factor: tuple = ("choice", (0.99, 1.01)),
        name: str = None,
        prob: float = 1.0,
    ):
        super().__init__(name=name, prob=prob)
        
        self.factor = factor

    def _instantiate(self, state: RandomState):

        factor = sample_from_dist(self.factor, state)
        return {"factor": factor}

    def _transform(self, signal, factor):

        if isinstance(factor, (float, int)):
            factor = [factor]

        out = signal.clone()
        
        for i, _factor in enumerate(factor):

            src_rate = int(_factor * signal.sample_rate)  # Use this item's factor, not the whole batch
            tgt_rate = int(signal.sample_rate)
            
            # Keep GCD of source and target sample rates reasonably large to 
            # limit resampling kernel size
            assert not tgt_rate % 50
            src_rate = (src_rate // 50 * 50)

            _out = out[i].clone()
            _out.sample_rate = src_rate
            _out = _out.resample(tgt_rate)
            _len = min(out.shape[-1], _out.shape[-1])
            out.audio_data[i, :, :_len] = _out.audio_data[..., :_len]

        return out

In [None]:
spectralgate = MaskLowMagnitudes(
    db_cutoff=("const", -20),
    prob=1.0,
)
speed = Speed()
attack = Compose(speed, spectralgate)


tfm_kwargs = spectralgate.instantiate(0, signal)
out = spectralgate.transform(signal.clone(), **tfm_kwargs)
out.widget()


tfm_kwargs = speed.instantiate(0, signal)
out = speed.transform(signal.clone(), **tfm_kwargs)
out.widget()

In [None]:
n_examples = 500

recordings = [ASSETS_DIR / "audio" / f"bryan_{i}.wav" for i in range(4)]

scores_watermarked = []
scores_unwatermarked = []

for i in tqdm(range(n_examples)):

    # Select random recording
    signal = AudioSignal(random.choice(recordings)).resample(sample_rate)
    
    # Embed watermark
    watermarked, k = embed(
        signal.clone(),
        sample_rate,
        stft_params,
        f_min_hz,
        f_max_hz,
        n_s,
        n_r,
        beta=beta,
        seed=i
    )

    # Apply random transformation to both unwatermarked and watermarked signal
    tfm_kwargs = attack.instantiate(i, signal)  # Seed with i so each example gets a different distortion
    tfm_signal = attack.transform(signal.clone(), **tfm_kwargs)
    tfm_watermarked = attack.transform(watermarked.clone(), **tfm_kwargs)
    
    # Detect watermark
    scores_watermarked += [
        max(
            detect(
                tfm_watermarked,
                sample_rate,
                stft_params,
                f_min_hz,
                f_max_hz,
                n_s,
                n_r,
                k
            )
        )
    ]
    scores_unwatermarked += [
        max(
            detect(
                tfm_signal,
                sample_rate,
                stft_params,
                f_min_hz,
                f_max_hz,
                n_s,
                n_r,
                k
            )
        )
    ]


print(
    f"TPR @ 10% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.10) :0.2f}\n"
    f"TPR @ 1% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.01) :0.2f}\n"
    f"TPR @ 0.1% FPR: {tpr_at_fpr(scores_watermarked, scores_unwatermarked, 0.001) :0.2f}\n"
)

And at last, we've found the limit of this watermark's robustness. What does it take to break it? Let's have a listen.

In [None]:
signal.widget()
watermarked.widget()
tfm_signal.widget()
tfm_watermarked.widget()