# End-to-End Neural Network Watermark

In this notebook, we'll implement and evaluate a simplified end-to-end neural network-based audio watermarking system. To emphasize the generality of the principles at play here, we'll define a "vanilla" transformer-based architecture rather than trying to adopt the architecture of a specific published watermarking system.

## Google Colab Setup

The cells below handle installation and configuration for Google Colab environments.

In [None]:
# Check if running in Google Colab
try:
    import google.colab
    COLAB = True
    print("Google Colab runtime detected")
except ImportError:
    COLAB = False

# Mount Google Drive to allow for persistent storage (and avoid re-downloading
# code and data)
if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

In [None]:
%%bash

# If running in Colab:
if [[ -n "$COLAB_RELEASE_TAG" ]]; then
  BASE="/content/drive/MyDrive"
  REPO_DIR="$BASE/wm_tutorial"

  if [[ ! -d "$REPO_DIR" ]]; then
    echo "Repo not found — cloning and installing..."
    mkdir -p "$BASE"
    cd "$BASE" && git clone https://github.com/oreillyp/wm_tutorial.git
  else
    echo "Repo already exists — installing without cloning..."
  fi

  cd "$REPO_DIR" && pip install -e .
fi

In [None]:
# Make sure `wm_tutorial` is visible
if COLAB:
    import site
    site.main()

## Transformers for Sequence Processing

In [None]:
import random
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader

from audiotools import AudioSignal, STFTParams
from audiotools.data.datasets import AudioLoader, AudioDataset
from audiotools.data.preprocess import create_csv
from audiotools.data.transforms import BaseTransform, Compose, BackgroundNoise, RoomImpulseResponse, MaskLowMagnitudes, Equalizer

from wm_tutorial.constants import DATA_DIR, ASSETS_DIR, MANIFESTS_DIR
from wm_tutorial.util import count_parameters, collate, tpr_at_fpr, snr, si_sdr
from wm_tutorial.nn.transformer import Transformer
from wm_tutorial.nn.message import MessageEmbedding, MessageBlock
from wm_tutorial.tfm import Noise, Reverb, Speed

NOISE_DIR = DATA_DIR / "noise-database" / "room"
RIR_DIR = DATA_DIR / "rir-database" / "real"

sample_rate = 16_000
window_length_ms = 50
max_len_s = 5.0
n_train_steps = 10_000
batch_size = 32  # Adjust as needed to account for your GPU memory!

# Watermark message length, i.e. capacity
n_bits = 16

# Check if GPU available
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# STFT parameters
stft_params = STFTParams(
    window_length=int(sample_rate * window_length_ms / 1000),
    hop_length=int(sample_rate * window_length_ms * 0.75 / 1000),
    window_type="tukey",
)

# Transformer config
config = {
    "model_dim": 256,
    "num_heads": 8,
    "num_layers": 8,
    "bias": False,
    "dropout": 0.1,
    "max_len": int(max_len_s * sample_rate / stft_params.hop_length) + 2,
    "dim_feedforward": 1024,
    "pos_enc": "absolute",
}

# Initialize a small transformer
tfmr = Transformer(**config)
count_parameters(tfmr)

Our transformer processes sequential data of shape `(n_batch, seq_len, model_dim)` and produces outputs of the same shape. Ideally, this processing extracts, contextualizes, and enriches the information in the sequence that is relevant to our training task. We won't dive deep into the transformer architecture in this tutorial; instead, we'll treat it as a "black-box" building block for our watermark embedding and detection algorithms.

In [None]:
# Process a random input sequence
x = torch.randn(1, config["max_len"], config["model_dim"])
out = tfmr(x, mask=None)

print(f"Input shape: {x.shape}, output shape: {out.shape}")

## Representing Watermark Messages

In end-to-end neural network-based watermarking systems, watermark keys or messages typically take the form of fixed-length binary vectors. Because these systems operate by predicting a message directly from audio -- i.e., in a _steganographic_ manner -- the message size is essentially the system's information __capacity__. For instance, a system capable of embedding length-16 binary vectors in audio and reconstructing them at detection time has a 16-bit capacity.
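
As a back-of-the-envelope illustration of what capacity buys us, suppose a detector guessed each bit independently at random (an idealization; a real detector run on unwatermarked audio may be biased toward particular messages). The number of correctly recovered bits would then follow a Binomial(`n_bits`, 0.5) distribution, so the chance of matching at least `k` of `n` bits is the sum of C(n, i) for i = k..n, divided by 2^n:

```python
from math import comb

def chance_of_at_least(k: int, n_bits: int) -> float:
    """P(at least k of n_bits bits match), if each bit is an independent fair coin flip."""
    return sum(comb(n_bits, i) for i in range(k, n_bits + 1)) / 2 ** n_bits

print(f"Distinct 16-bit messages: {2 ** 16}")
for k in (16, 15, 14, 13, 12):
    print(f"P(>= {k}/16 bits correct by chance) = {chance_of_at_least(k, 16):.2e}")
```

For 16 bits this works out to roughly 1.5e-05 for a perfect match, 2.6e-04 for at least 15 correct bits, and 3.8e-02 for at least 12. Under this idealized model, longer messages make a high bit accuracy increasingly unlikely to occur by chance.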

How do we pass this message to our watermark embedding network? Most neural network architectures are designed to operate on _continuous_ vector representations of data, not booleans. We can take a page from large language models, which represent sequences of text "tokens" (in reality, integer indices into a vocabulary) as learned embedding vectors. What if we did something similar with binary vectors?

In [None]:
# Sample a watermark message (binary vector)
msg = torch.randint(0, 2, (n_bits,))

# Create an "embedding" lookup table that maps a binary value {0, 1} to one of two corresponding continuous vectors
emb_table = torch.nn.Embedding(2, config["model_dim"])

embedded = emb_table(msg)

print(
    f"Original message: {msg.tolist()}\n"
    f"Message tensor shape: {msg.shape}\n"
    f"Embedded message tensor shape: {embedded.shape}\n"
)

plt.imshow(
    msg.unsqueeze(0).repeat(config["model_dim"], 1), 
    origin="lower", aspect="auto", interpolation="none"
)
plt.title("Binary message")
plt.yticks([])
plt.xlabel("Message sequence idx")
plt.show()

plt.imshow(
    embedded.detach().T, 
    origin="lower", aspect="auto", interpolation="none"
)
plt.title("Embedded message")
plt.xlabel("Message sequence idx")
plt.ylabel("Channel dim")
plt.show()

Now our message is a sequence of continuous vectors, which is a good first step. But how do we distinguish between a `0` in the first position and a `0` in the tenth position? Both are mapped to the same vector, so there is no explicit encoding of positional information in this representation! Some architectures (e.g. convolutional or recurrent networks) capture ordering implicitly through their structure, but self-attention does not, so transformers need positional information to be supplied explicitly.
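
A quick way to see why transformers need this is to check that self-attention is permutation-equivariant. The sketch below uses PyTorch's built-in `torch.nn.MultiheadAttention` as a stand-in for a transformer layer (an assumption for illustration; it is not the tutorial's `Transformer`): shuffling the input sequence only shuffles the outputs in the same way, so without positional encoding nothing in the output reveals where an element sat in the sequence.

```python
import torch

torch.manual_seed(0)

seq_len, model_dim = 16, 32
attn = torch.nn.MultiheadAttention(model_dim, num_heads=4, batch_first=True)
attn.eval()

x = torch.randn(1, seq_len, model_dim)
perm = torch.randperm(seq_len)  # a random reordering of the sequence

with torch.no_grad():
    out, _ = attn(x, x, x)                      # self-attention, original order
    x_perm = x[:, perm]
    out_perm, _ = attn(x_perm, x_perm, x_perm)  # self-attention, shuffled order

# Shuffling the inputs just shuffles the outputs the same way: no positional signal
is_equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)
print(is_equivariant)  # True
```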

One solution is to apply a standard positional encoding (e.g. sinusoidal, RoPE) to our embedded message sequence. Another approach, used by San Roman et al. in the [AudioSeal watermarking system](https://arxiv.org/abs/2401.17264), is to learn a separate pair of embedding vectors for each index in the message. That is, the first index's 0/1 vectors are separate from the second index's 0/1 vectors, and so on. It is common to sum the resulting embedding sequence along the message dimension to obtain a single vector that represents all bit values and positions.

In [None]:
# Initialize an embedding table with two entries per bit
emb_table_with_pos = torch.nn.Embedding(2 * n_bits, config["model_dim"])

# Modify message so that each entry now holds an index into the correct
# bit value / position index in our embedding table
msg_with_pos = msg + 2 * torch.arange(n_bits)

# Get bit value / position embeddings
embedded_with_pos = emb_table_with_pos(msg_with_pos).sum(0)

print(
    f"Original message: {msg.tolist()}\n"
    f"Message with offsets: {msg_with_pos.tolist()}\n"
    f"Summed embedded message tensor shape: {embedded_with_pos.shape}"
)

In practice, we could represent our message in a number of ways. For now, though, we'll stick with this format and move on to the process of actually "hiding" our messages.

How can we modify a given sequence of values (e.g. representing audio data) to incorporate information from our message? One simple mechanism is __cross-attention__, which is similar to the self-attention operation at the core of the transformer architecture. Whereas self-attention defines and operates on relationships between elements within a single sequence, cross-attention defines and operates on relationships between two sequences. In our case, we will allow each position in a sequence representing our audio data to "attend to" a representation of our watermark message. A nice thing about this method is that it works whether we squeeze our message down into a single vector (like in AudioSeal) or keep it as a sequence of values (as long as we remember to sprinkle in some positional encoding!).
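
One detail worth keeping in mind when the message is summed into a single vector, as in the AudioSeal-style embedding above: with only one key, the softmax over keys equals 1 at every query position, so cross-attention hands the same projected message vector to every frame regardless of the query. The minimal check below uses `torch.nn.functional.scaled_dot_product_attention` with arbitrary, made-up shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n_batch, n_heads, seq_len, head_dim = 2, 4, 10, 8
q = torch.randn(n_batch, n_heads, seq_len, head_dim)  # one query per audio frame
k = torch.randn(n_batch, n_heads, 1, head_dim)        # a single message key
v = torch.randn(n_batch, n_heads, 1, head_dim)        # a single message value

out = F.scaled_dot_product_attention(q, k, v)         # (n_batch, n_heads, seq_len, head_dim)

# Softmax over a single key is always 1, so every frame receives exactly v
broadcast_v = v.expand(-1, -1, seq_len, -1)
same_everywhere = torch.allclose(out, broadcast_v, atol=1e-6)
print(same_everywhere)  # True
```

In this single-vector case, cross-attention effectively broadcasts a learned, message-dependent vector to every frame; keeping the message as a sequence of per-bit tokens (with positional information) is what allows the attention weights themselves to vary from frame to frame.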

In [None]:
# The cross-attention operation:
# 1. Project "audio data" sequence to obtain query sequence
# 2. Project "message" sequence to obtain key and value sequences
# 3. Split into multiple "heads" along hidden dimension
# 4. Perform standard attention using keys/queries/values
k_proj = torch.nn.Linear(config["model_dim"], config["model_dim"])
v_proj = torch.nn.Linear(config["model_dim"], config["model_dim"])
q_proj = torch.nn.Linear(config["model_dim"], config["model_dim"])

# 1.
q = q_proj(x)                                 # (n_batch, seq_len, model_dim)

# 2.
k = k_proj(embedded_with_pos.view(1, 1, -1))  # (n_batch, 1, model_dim)
v = v_proj(embedded_with_pos.view(1, 1, -1))  # (n_batch, 1, model_dim)

# 3.
n_batch, tgt_len, _ = q.shape
_, src_len, model_dim = k.shape
n_heads = 2
head_dim = model_dim // n_heads

q = q.view(n_batch, tgt_len, n_heads, head_dim).transpose(1, 2)  # (n_batch, n_heads, seq_len, head_dim)
k = k.view(n_batch, src_len, n_heads, head_dim).transpose(1, 2)  # (n_batch, n_heads, 1, head_dim)
v = v.view(n_batch, src_len, n_heads, head_dim).transpose(1, 2)  # (n_batch, n_heads, 1, head_dim)

# 4.
cross_attn_output = torch.nn.functional.scaled_dot_product_attention(
    q,
    k,
    v,
    attn_mask=None,
)  # (n_batch, n_heads, seq_len, head_dim)
cross_attn_output = cross_attn_output.transpose(1, 2).contiguous().view(n_batch, tgt_len, -1)  # (n_batch, seq_len, model_dim)

plt.imshow(x[0].T, aspect="auto", origin="lower", interpolation="none")
plt.xlabel("Frames")
plt.ylabel("Channel dim")
plt.title("Input sequence")
plt.show()

plt.imshow(cross_attn_output[0].T.detach(), aspect="auto", origin="lower", interpolation="none")
plt.xlabel("Frames")
plt.ylabel("Channel dim")
plt.title("Cross-attention output")
plt.show()

We can interleave many such cross-attention layers within our network to repeatedly embed aspects of the watermark message. Which brings us to...

## The Watermark Embedder

The watermark embedding network is tasked with modifying a given audio signal to incorporate a given message. We will do so in the frequency domain, modifying a magnitude spectrogram representation of our audio.
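
The core mechanism is a non-negative multiplicative mask on the STFT magnitude, with the original phase reused for resynthesis. The sketch below illustrates it with plain `torch.stft`/`torch.istft` rather than `audiotools`; the Hann window and quarter-window hop are illustrative choices, not the tutorial's `stft_params`, and `apply_magnitude_mask` is a hypothetical helper. An all-ones mask should reconstruct the input almost exactly, and a constant mask simply rescales the waveform.

```python
import math
import torch

torch.manual_seed(0)

n_fft = 800                    # 50 ms window at 16 kHz, as in the tutorial
hop_length = n_fft // 4        # illustrative; satisfies the overlap-add (NOLA) condition
window = torch.hann_window(n_fft)

audio = torch.randn(2, 16_000)  # (n_batch, n_samples): 1 s of noise as stand-in audio

def apply_magnitude_mask(audio, mask_fn):
    """Scale STFT magnitudes by a non-negative mask, keep the original phase, resynthesize."""
    spec = torch.stft(audio, n_fft, hop_length=hop_length, window=window, return_complex=True)
    mag, phase = spec.abs(), spec.angle()  # each (n_batch, n_freq, n_frames)
    new_spec = torch.polar(mag * mask_fn(mag), phase)
    return torch.istft(new_spec, n_fft, hop_length=hop_length, window=window, length=audio.shape[-1])

# An all-ones mask should leave the waveform (numerically) unchanged
identity = apply_magnitude_mask(audio, torch.ones_like)
recon_err = (identity - audio).abs().max().item()

# A constant softplus(0) = log(2) mask simply rescales the waveform by log(2)
logits = torch.zeros(2, n_fft // 2 + 1, 1)  # stand-in for out_proj outputs (one value per bin)
scaled = apply_magnitude_mask(
    audio, lambda mag: torch.nn.functional.softplus(logits).expand_as(mag)
)
print(recon_err, torch.allclose(scaled, math.log(2) * audio, atol=1e-4))
```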

In [None]:
class Embedder(torch.nn.Module):

    def __init__(
        self, 
        sample_rate: int,
        stft_params: STFTParams,
        n_bits: int = 8,
        gate: str = "sigmoid",
        *args,
        **kwargs,
    ):
        super().__init__()
        
        # Transformer backbone
        self.transformer = Transformer(*args, **kwargs)

        # Construct message-embedding cross-attention blocks to interleave 
        # within transformer backbone
        cross_attn = [
            MessageBlock(gate=gate, *args, **kwargs) 
            for _ in range(self.transformer.num_layers)
        ]

        self.cross_attn = torch.nn.ModuleList(cross_attn)

        # AudioSeal-style message embedding from above, with some extra bells
        # and whistles
        self.msg_emb = MessageEmbedding(n_bits, self.transformer.model_dim)

        # Project inputs to match hidden dimension
        self.in_proj = torch.nn.Linear(
            stft_params.window_length // 2 + 1, self.transformer.model_dim,
        )
        
        # Project outputs to match spectrogram dimension
        self.out_proj = torch.nn.Linear(
            self.transformer.model_dim, stft_params.window_length // 2 + 1,
        )
        
        # Ensure non-negative outputs
        self.softplus = torch.nn.Softplus()
        
        self.sample_rate = sample_rate
        self.stft_params = stft_params
        self.n_bits = n_bits
    
    def forward(
        self, 
        signal: AudioSignal, 
        msg: torch.Tensor,
        signal_lengths: torch.Tensor = None,
        msg_mask: torch.Tensor = None,
    ):
        
        assert msg.ndim == 2  # (n_batch, msg_len)
        n_batch = signal.batch_size

        # If no signal lengths provided, assume full-length
        if signal_lengths is None:
            signal_lengths = torch.full(
                (n_batch,),
                signal.signal_length, 
                dtype=torch.long, 
                device=signal.device
            )

        # Compute spectrogram of input audio
        watermarked = signal.clone().resample(self.sample_rate)
        watermarked.stft_params = self.stft_params
        watermarked.stft()

        mag = watermarked.magnitude  # (n_batch, n_channels, n_freq, n_frames)
        phase = watermarked.phase    # (n_batch, n_channels, n_freq, n_frames)

        # Project to match model dimension
        x = self.in_proj(mag.mean(1).transpose(1, 2))  # (n_batch, n_frames, model_dim)

        # Convert signal lengths to spectrogram frame resolution
        lengths = signal_lengths * (x.shape[1] / signal.signal_length)
        lengths = lengths.clamp(max=x.shape[1]).long()

        # Construct padding mask to prevent attention to padded frames
        lengths_expanded = lengths.to(x.device).view(n_batch, 1, 1)  # (n_batch, 1, 1)
        range_tensor = torch.arange(x.shape[1], device=x.device).view(1, 1, x.shape[1])
        padding_mask = (range_tensor < lengths_expanded).repeat(1, x.shape[1], 1)  # (n_batch, n_frames, n_frames)
        
        # Embed message
        msg_emb = self.msg_emb(msg, mask=msg_mask).unsqueeze(1)  # (n_batch, 1, model_dim)
        
        # Pass through model
        for i, (attn_layer, cross_attn_layer) in enumerate(
            zip(self.transformer.layers, self.cross_attn)
        ):
            x = cross_attn_layer(x, msg_emb)
            x = attn_layer(x, padding_mask)  # (n_batch, n_frames, model_dim)

        # Project to obtain a multiplicative magnitude spectrogram mask and
        # apply softplus to ensure nonnegative
        out = self.out_proj(x)  # (n_batch, n_frames, n_freq)
        out = self.softplus(out)

        # Apply multiplicative mask
        mag = mag * out.unsqueeze(1).transpose(2, 3)  # (n_batch, n_channels, n_freq, n_frames)

        # Invert spectrogram to obtain audio
        watermarked.magnitude = mag
        watermarked.phase = phase
        watermarked = watermarked.istft(
            window_length=self.stft_params.window_length,
            hop_length=self.stft_params.hop_length,
            window_type=self.stft_params.window_type,
            match_stride=self.stft_params.match_stride,
            length=watermarked.signal_length,
        )

        # Restore sample rate
        watermarked.resample(signal.sample_rate)
        watermarked.stft_params = signal.stft_params

        # Ensure length matches
        watermarked.audio_data = watermarked.audio_data[..., :signal.signal_length]
        watermarked.audio_data = torch.nn.functional.pad(
            watermarked.audio_data, 
            (0, max(0, signal.signal_length - watermarked.signal_length))
        )

        return watermarked

In [None]:
# Initialize embedding network
embedder = Embedder(sample_rate, stft_params, n_bits, **config)
print(f"Embedder parameters: {count_parameters(embedder)}")

# Load audio data
signal = AudioSignal(ASSETS_DIR / "audio" / "bryan_0.wav").resample(sample_rate)
signal = signal[..., :int(signal.sample_rate * max_len_s)]
signal.widget()

# Embed a random watermark message in signal
watermarked = embedder(signal, torch.randint(0, 2, (signal.batch_size, n_bits))).detach()
watermarked.widget()

# Examine difference between original and "watermarked" signal
(watermarked - signal).widget()

Note that we haven't actually watermarked anything here -- our network is randomly initialized and has not yet been trained to hide messages in a recoverable manner! How can we train our embedder to do this? We'll need to start by defining...

## The Watermark Detector

Our detector inverts the processing performed by the embedder: given an audio signal, it predicts a fixed-length watermark message. To map audio recordings of arbitrary length to message predictions, we'll use __attention pooling__. If we prepend a learnable latent vector to our encoded audio sequence and predict the message from the transformer's output at only this sequence position, our detector network will naturally learn to pool relevant information from across the sequence into this position.
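
Here is a minimal, self-contained sketch of the same idea built from `torch.nn.TransformerEncoder` rather than the tutorial's `Transformer`; the `PooledDetector` name and all module sizes are illustrative assumptions. Note that PyTorch's `src_key_padding_mask` uses `True` to mark positions to *ignore*, the opposite of the `range_tensor < lengths` mask built in the `Detector`, which marks valid frames with `True`.

```python
import torch

class PooledDetector(torch.nn.Module):
    """Hypothetical minimal detector: prepend a learnable pool token, read out position 0."""

    def __init__(self, in_dim: int, model_dim: int = 64, n_bits: int = 16):
        super().__init__()
        self.in_proj = torch.nn.Linear(in_dim, model_dim)
        self.pool_emb = torch.nn.Parameter(torch.randn(model_dim))
        layer = torch.nn.TransformerEncoderLayer(model_dim, nhead=4, batch_first=True)
        self.encoder = torch.nn.TransformerEncoder(layer, num_layers=2)
        self.out_proj = torch.nn.Linear(model_dim, n_bits)

    def forward(self, feats: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        # feats: (n_batch, n_frames, in_dim); lengths: (n_batch,) valid frames per example
        n_batch, n_frames, _ = feats.shape
        x = self.in_proj(feats)
        pool = self.pool_emb.view(1, 1, -1).expand(n_batch, 1, -1)
        x = torch.cat([pool, x], dim=1)  # (n_batch, n_frames + 1, model_dim)

        # True = padding to ignore (PyTorch convention); the pool token is never masked
        positions = torch.arange(n_frames + 1, device=feats.device)
        key_padding_mask = positions.unsqueeze(0) >= (lengths + 1).unsqueeze(1)

        x = self.encoder(x, src_key_padding_mask=key_padding_mask)
        return torch.sigmoid(self.out_proj(x[:, 0]))  # (n_batch, n_bits)

torch.manual_seed(0)
det = PooledDetector(in_dim=401).eval()  # eval() disables dropout
feats = torch.randn(3, 50, 401)
lengths = torch.tensor([50, 30, 10])
probs = det(feats, lengths)
print(probs.shape)  # torch.Size([3, 16])

# Overwriting only the padded frames of example 1 should not change its prediction
feats2 = feats.clone()
feats2[1, 30:] = 100.0
same = torch.allclose(det(feats2, lengths)[1], probs[1], atol=1e-5)
print(same)
```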

In [None]:
class Detector(torch.nn.Module):

    def __init__(
        self, 
        sample_rate: int,
        stft_params: STFTParams,
        n_bits: int = 8,
        *args,
        **kwargs,
    ):
        super().__init__()
        
        # Transformer backbone
        self.transformer = Transformer(*args, **kwargs)

        # Learnable "pool" embedding
        self.pool_emb = torch.nn.Parameter(
            torch.randn(self.transformer.model_dim),
            requires_grad=True,
        )
        
        # Project inputs to match hidden dimension
        self.in_proj = torch.nn.Linear(
            stft_params.window_length // 2 + 1, self.transformer.model_dim,
        )
        
        # Project outputs to obtain watermark vector predictions
        self.out_proj = torch.nn.Linear(
            self.transformer.model_dim, n_bits,
        )
        
        self.sample_rate = sample_rate
        self.stft_params = stft_params
        self.n_bits = n_bits
    
    def forward(
        self, 
        signal: AudioSignal, 
        signal_lengths: torch.Tensor = None,
    ):
        
        n_batch = signal.batch_size

        # If no signal lengths provided, assume full-length
        if signal_lengths is None:
            signal_lengths = torch.full(
                (n_batch,),
                signal.signal_length, 
                dtype=torch.long, 
                device=signal.device
            )

        # Compute spectrogram of input audio
        orig_signal_length = signal.signal_length
        signal = signal.clone().resample(self.sample_rate)
        signal.stft_params = self.stft_params
        signal.stft()

        mag = signal.magnitude  # (n_batch, n_channels, n_freq, n_frames)

        # Project to match model dimension
        x = self.in_proj(mag.mean(1).transpose(1, 2))  # (n_batch, n_frames, model_dim)

        # Convert signal lengths to spectrogram frame resolution
        lengths = signal_lengths * (x.shape[1] / orig_signal_length)
        lengths = lengths.clamp(max=x.shape[1]).long()

        # Prepend learnable "pool" embedding and adjust lengths accordingly
        x = torch.cat(
            [self.pool_emb.view(1, 1, -1).repeat(n_batch, 1, 1), x],
            dim=1,
        )  # (n_batch, n_frames + 1, model_dim)
        lengths += 1  # (n_batch,)
        
        # Construct padding mask to prevent attention to padded frames
        lengths_expanded = lengths.to(x.device).view(n_batch, 1, 1)  # (n_batch, 1, 1)
        range_tensor = torch.arange(x.shape[1], device=x.device).view(1, 1, x.shape[1])
        padding_mask = (range_tensor < lengths_expanded).repeat(1, x.shape[1], 1)  # (n_batch, n_frames + 1, n_frames + 1)
        
        # Pass through model
        for i, attn_layer in enumerate(self.transformer.layers):
            x = attn_layer(x, padding_mask)  # (n_batch, n_frames + 1, model_dim)

        # Project output of first frame (corresponding to "pool" embedding) to obtain
        # watermark prediction
        out = self.out_proj(x[:, 0, :])  # (n_batch, n_bits)

        # Bound to [0, 1]
        return torch.sigmoid(out)  # (n_batch, n_bits)

In [None]:
# Initialize detector network
detector = Detector(sample_rate, stft_params, n_bits, **config)
print(f"Detector parameters: {count_parameters(detector)}")

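# Embed a fresh random message, then pass the watermarked audio through the
# detector to obtain a watermark prediction
msg = torch.randint(0, 2, (signal.batch_size, n_bits))
watermarked = embedder(signal, msg).detach()
pred = detector(watermarked).detach() > 0.5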

# Plot detection results; for untrained embedder/detector, we should see
# random (~50%) accuracy in recovering watermark vector
plt.plot(msg.flatten().tolist(), label="actual watermark")
plt.plot(pred.flatten().tolist(), label="predicted watermark")
plt.title(f"Actual vs. predicted watermark vector (accuracy={(pred.flatten()==msg.flatten()).float().mean() :0.2f})")
plt.legend()
plt.show()

## Training End-to-End

Now that we have our embedder and detector networks, we want to train them jointly to embed and detect watermarks! We'll chain together the two steps above, passing the embedder's output to the detector as input, and then compute two losses: a __detection loss__ that encourages the detector to accurately identify the embedded watermark message, and a __perceptual transparency loss__ that encourages the embedder to produce watermarked audio that differs from the input as little as possible. The detection loss is computed on the detector's outputs, and its gradients flow through both the detector and embedder networks; the perceptual transparency loss is computed on the embedder's outputs, and its gradients flow only to the embedder.
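
The gradient-flow claim is easy to check with toy stand-ins. Below, two hypothetical linear layers play the roles of embedder and detector (the real networks are transformers operating on spectrograms), a waveform L2 distance stands in for the perceptual transparency loss, and `binary_cross_entropy_with_logits` is used as a numerically stabler equivalent of applying a sigmoid followed by `binary_cross_entropy`. Backpropagating the transparency term alone leaves the detector without gradients, while the detection term reaches both networks.

```python
import torch

torch.manual_seed(0)
n_batch, n_samples, n_bits = 4, 64, 16

# Hypothetical toy stand-ins for the real Embedder / Detector
toy_embedder = torch.nn.Linear(n_samples + n_bits, n_samples)  # (audio, message) -> residual
toy_detector = torch.nn.Linear(n_samples, n_bits)              # audio -> per-bit logits

audio = torch.randn(n_batch, n_samples)
msg = torch.randint(0, 2, (n_batch, n_bits)).float()

watermarked = audio + toy_embedder(torch.cat([audio, msg], dim=-1))
logits = toy_detector(watermarked)

loss_d = torch.nn.functional.binary_cross_entropy_with_logits(logits, msg)
loss_pt = (watermarked - audio).reshape(n_batch, -1).norm(dim=-1).mean()

# Transparency loss alone: gradients reach the embedder but never the detector
loss_pt.backward(retain_graph=True)
pt_reaches_embedder = toy_embedder.weight.grad is not None
pt_reaches_detector = toy_detector.weight.grad is not None

# Detection loss: gradients flow through the detector and back into the embedder
toy_embedder.zero_grad()
loss_d.backward()
d_reaches_embedder = toy_embedder.weight.grad.abs().sum().item() > 0
d_reaches_detector = toy_detector.weight.grad.abs().sum().item() > 0

print(pt_reaches_embedder, pt_reaches_detector, d_reaches_embedder, d_reaches_detector)
# True False True True
```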

In [None]:
# Initialize networks
embedder = Embedder(sample_rate, stft_params, n_bits, **config).to(device)
detector = Detector(sample_rate, stft_params, n_bits, **config).to(device)

print(f"Embedder parameters: {count_parameters(embedder)}")
print(f"Detector parameters: {count_parameters(detector)}")

opt_embedder = torch.optim.AdamW(embedder.parameters(), lr=3e-4)
opt_detector = torch.optim.AdamW(detector.parameters(), lr=3e-4)

# Load audio data
loader = AudioLoader(sources=[DATA_DIR/"LibriTTS_R/train-clean-100"])
dataset = AudioDataset(
    loader, 
    sample_rate=sample_rate,
    n_examples=1_000_000,       # Number of unique rows to load
    duration=max_len_s,         # Pad/trim all audio to this duration
    loudness_cutoff=-40,        # Sample random excerpts until minimum loudness (dB) cutoff is met
    num_channels=1,             # If 1, downmix all audio to mono
    without_replacement=False,  # Sample audio files with/without replacement
)

# Initialize data loader
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate,
    shuffle=True
)

In [None]:
acc_d = []
loss_d = []
loss_pt = []

train_iter = iter(dataloader)

pbar = tqdm(range(n_train_steps))
for i in pbar:
    
    embedder.train()
    detector.train()
    
    opt_embedder.zero_grad()
    opt_detector.zero_grad()

    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(dataloader)
        batch = next(train_iter) 

    idx = batch["idx"].to(device)
    signal = batch["signal"].to(device)                     # (n_batch, n_channels, n_samples)
    signal_lengths = batch["signal_lengths"].to(device)     # (n_batch,)
    
    # Sample random watermark message
    msg = torch.randint(
        0, 2, 
        (signal.batch_size, n_bits), 
        device=device
    )  # (n_batch, n_bits)

    # Embed message
    watermarked = embedder(signal, msg, signal_lengths)  # (n_batch, n_channels, n_samples)

    # Detect message
    msg_pred = detector(watermarked, signal_lengths)  # (n_batch, n_bits)
    
    # Compute detection loss
    _loss_d = torch.nn.functional.binary_cross_entropy(
        msg_pred, msg.float(), reduction="none",
    ).mean(dim=-1)  # (n_batch,)
    
    # Compute perceptual transparency loss
    _loss_pt = (
        watermarked.magnitude - signal.magnitude
    ).reshape(n_batch, -1).norm(dim=-1)  # (n_batch,)

    # Combine losses
    _loss = _loss_d + 0.01 * _loss_pt  # (n_batch,)

    # Backward pass
    _loss.mean().backward()
    torch.nn.utils.clip_grad_norm_(embedder.parameters(), 1.0)
    torch.nn.utils.clip_grad_norm_(detector.parameters(), 1.0)

    # Update
    opt_embedder.step()
    opt_detector.step()

    # Logging
    with torch.no_grad():
        acc_d += [((msg_pred > 0.5) == msg).float().mean().item()]
        loss_d += [_loss_d.mean().item()]
        loss_pt += [_loss_pt.mean().item()]

    pbar.set_description(
        f"Detection accuracy: {acc_d[-1] :0.2f}, "
        f"Detection loss: {loss_d[-1] :0.2f}, "
        f"Perceptual transparency loss: {loss_pt[-1] :0.2f}"
    )    

In [None]:
plt.plot(acc_d)
plt.title("Detection accuracy (training)")
plt.xlabel("Training step")
plt.show()

plt.plot(loss_d)
plt.title("Detection loss (training)")
plt.xlabel("Training step")
plt.show()

plt.plot(loss_pt)
plt.title("Perceptual transparency loss (training)")
plt.xlabel("Training step")
plt.show()

We can now try out our trained watermark system on an unseen recording!

In [None]:
# Switch to evaluation mode (disables dropout)
embedder.eval()
detector.eval()

# Load audio
signal = AudioSignal(ASSETS_DIR / "audio" / "bryan_0.wav").resample(sample_rate).to(device)
signal = signal[..., :int(max_len_s * signal.sample_rate)]

msg = torch.randint(0, 2, (1, n_bits), device=device)
signal.clone().detach().cpu().widget()

with torch.no_grad():
    # Embed watermark
    watermarked = embedder(signal, msg)
    watermarked.clone().detach().cpu().widget()

    # Run detection on both original and watermarked signals
    msg_pred_unwatermarked = detector(signal) > 0.5
    msg_pred_watermarked = detector(watermarked) > 0.5

# Evaluate perceptual transparency
print(
    f"Watermark SNR: {snr(watermarked, signal).item() :0.2f}\n"
    f"Watermark SI-SDR: {si_sdr(watermarked, signal).item() :0.2f}\n"
)

# Plot detection results
plt.plot(msg.flatten().tolist(), label="actual watermark")
plt.plot(msg_pred_unwatermarked.flatten().tolist(), label="predicted (unwatermarked)")
plt.plot(msg_pred_watermarked.flatten().tolist(), label="predicted (watermarked)")
plt.xlabel("Watermark message bit")
plt.title(
    f"Actual vs. predicted watermark vector\n"
    f"(unwatermarked accuracy={(msg_pred_unwatermarked.flatten()==msg.flatten()).float().mean() :0.2f})\n"
    f"(watermarked accuracy={(msg_pred_watermarked.flatten()==msg.flatten()).float().mean() :0.2f})\n"
)
plt.legend()
plt.show()

## Robustness

Let's put our trained watermarking system to the test more rigorously. We'll sample a large number of recordings unseen during training along with corresponding random watermark messages, and compute the accuracy of the messages our detector recovers from both watermarked and unwatermarked audio. Using these accuracies as detection scores, we'll evaluate the performance of our system in terms of the true positive rate achievable at a fixed false positive rate.
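
`tpr_at_fpr` comes from `wm_tutorial.util`, and its implementation isn't shown here. One common way to compute this metric, sketched below as a hypothetical stand-in (`tpr_at_fpr_sketch`), is to pick the detection threshold from the unwatermarked (negative) scores so that at most the target fraction of them score above it, then report the fraction of watermarked (positive) scores above that threshold. Because bit accuracy on a 16-bit message takes only 17 distinct values, ties at the threshold are common; and with 1,000 unwatermarked examples, a 0.1% FPR allows just a single false positive.

```python
import math
import torch

def tpr_at_fpr_sketch(pos_scores: torch.Tensor, neg_scores: torch.Tensor, fpr: float) -> float:
    """Hypothetical stand-in for wm_tutorial.util.tpr_at_fpr.

    Picks the lowest threshold t among observed negative scores such that at most
    floor(fpr * n_neg) negatives score strictly above t, then returns the fraction
    of positives scoring strictly above t. Ties at t never count as detections.
    """
    max_false_pos = math.floor(fpr * len(neg_scores) + 1e-9)  # tolerance for float rounding
    for t in torch.unique(neg_scores):  # sorted ascending
        if (neg_scores > t).sum().item() <= max_false_pos:
            return (pos_scores > t).sum().item() / len(pos_scores)
    return 0.0  # only reached if neg_scores is empty

# Made-up bit-accuracy scores for 16-bit messages (multiples of 1/16)
neg = torch.tensor([0.4375, 0.5, 0.5, 0.5625, 0.625, 0.5, 0.6875, 0.5, 0.375, 0.75])
pos = torch.tensor([1.0, 0.9375, 1.0, 0.875, 0.75, 1.0, 0.625, 1.0, 0.9375, 1.0])
print(tpr_at_fpr_sketch(pos, neg, 0.10), tpr_at_fpr_sketch(pos, neg, 0.0))  # 0.9 0.8
```

Using a strict `>` keeps the empirical FPR at or below the target, at the cost of not counting watermarked examples whose score ties the threshold.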

In [None]:
# Load audio data
loader = AudioLoader(sources=[DATA_DIR/"LibriTTS_R/test-clean"])  # Test set!
dataset = AudioDataset(
    loader, 
    sample_rate=sample_rate,
    n_examples=1_000,
    duration=max_len_s,
    loudness_cutoff=-40,     
    num_channels=1,
    without_replacement=True,
)

# Initialize data loader
val_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate,
    shuffle=True
)

# Watermarked and unwatermarked scores
uwm_scores = []
wm_scores = []

pbar = tqdm(val_dataloader)
for batch in pbar:
    
    embedder.eval()
    detector.eval()

    idx = batch["idx"].to(device)
    signal = batch["signal"].to(device)                     # (n_batch, n_channels, n_samples)
    signal_lengths = batch["signal_lengths"].to(device)     # (n_batch,)
    
    # Sample random watermark message
    msg = torch.randint(
        0, 2, 
        (signal.batch_size, n_bits), 
        device=device
    )  # (n_batch, n_bits)

    with torch.no_grad():
    
        # Embed message
        watermarked = embedder(signal, msg)  # (n_batch, n_channels, n_samples)
    
        # Detect message
        msg_pred_wm = detector(watermarked)  # (n_batch, n_bits)
        msg_pred_uwm = detector(signal)      # (n_batch, n_bits)

        acc_wm = ((msg_pred_wm > 0.5) == msg).float().mean(dim=-1)    # (n_batch,)
        acc_uwm = ((msg_pred_uwm > 0.5) == msg).float().mean(dim=-1)  # (n_batch,)

        wm_scores += [acc_wm]
        uwm_scores += [acc_uwm]

wm_scores = torch.cat(wm_scores, dim=0)
uwm_scores = torch.cat(uwm_scores, dim=0)

print(
    f"TPR @ 10% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.10) :0.2f}\n"
    f"TPR @ 1% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.01) :0.2f}\n"
    f"TPR @ 0.1% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.001) :0.2f}\n"
)

Not too shabby for a bare-bones end-to-end system! And importantly, we didn't need to hand-craft a clever embedding or detection scheme; we let our networks learn everything for us. Now what happens if we perturb the watermarked audio before detection?

In [None]:
create_csv(
    audio_files=list((Path(NOISE_DIR)).rglob("*.wav")),
    output_csv=MANIFESTS_DIR / "noise_room.csv", 
)
create_csv(
    audio_files=list((Path(RIR_DIR)).rglob("*.wav")),
    output_csv=MANIFESTS_DIR / "rir_real.csv",
)

noise = BackgroundNoise(
    snr=("uniform", 20.0, 30.0),  # Sample noise level uniformly in [20, 30]dB
    sources=[MANIFESTS_DIR / "noise_room.csv"],
    eq_amount=("const", 1.0),     # Sample EQ level as a fixed value of 1.0
    n_bands=3,
    prob=1.0,
    loudness_cutoff=None,
)

reverb = RoomImpulseResponse(
    drr=("uniform", 10.0, 30.0),   # Sample reverb direct-to-reverberant ratio uniformly in [10, 30]dB
    sources=[MANIFESTS_DIR / "rir_real.csv"],
    eq_amount=("const", 1.0),     # Sample EQ level as a fixed value of 1.0
    n_bands=6,
    prob=1.0,
    use_original_phase=False,
    offset=0.0,
    duration=1.0,
)

noise_and_reverb = Compose(noise, reverb)

# Watermarked and unwatermarked scores
uwm_scores = []
wm_scores = []

pbar = tqdm(val_dataloader)
for batch in pbar:
    
    embedder.eval()
    detector.eval()

    idx = batch["idx"].to(device)
    signal = batch["signal"].to(device)                     # (n_batch, n_channels, n_samples)
    signal_lengths = batch["signal_lengths"].to(device)     # (n_batch,)
    
    # Sample random watermark message
    msg = torch.randint(
        0, 2, 
        (signal.batch_size, n_bits), 
        device=device
    )  # (n_batch, n_bits)

    with torch.no_grad():
    
        # Embed message
        watermarked = embedder(signal, msg, signal_lengths)  # (n_batch, n_channels, n_samples)

        # Apply randomized transformations
        tfm_kwargs = noise_and_reverb.batch_instantiate(idx.tolist(), signal)
        tfm_signal = noise_and_reverb.transform(signal.clone().cpu(), **tfm_kwargs).to(device)
        tfm_watermarked = noise_and_reverb.transform(watermarked.clone().cpu(), **tfm_kwargs).to(device)
        
        # Detect message
        msg_pred_wm = detector(tfm_watermarked)  # (n_batch, n_bits)
        msg_pred_uwm = detector(tfm_signal)      # (n_batch, n_bits)

        acc_wm = ((msg_pred_wm > 0.5) == msg).float().mean(dim=-1)    # (n_batch,)
        acc_uwm = ((msg_pred_uwm > 0.5) == msg).float().mean(dim=-1)  # (n_batch,)

        wm_scores += [acc_wm]
        uwm_scores += [acc_uwm]

wm_scores = torch.cat(wm_scores, dim=0)
uwm_scores = torch.cat(uwm_scores, dim=0)

print(
    f"TPR @ 10% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.10) :0.2f}\n"
    f"TPR @ 1% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.01) :0.2f}\n"
    f"TPR @ 0.1% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.001) :0.2f}\n"
)

Here, we can see that our approach fails under noise and reverb transformations. And that shouldn't be a surprise: we did not train our system to be robust to any transformations! Let's try again, but this time we'll incorporate transformations directly into our training loop. To avoid training on the test distribution, we'll use simplified simulations of background noise and reverberation (implemented in `wm_tutorial/tfm.py`) that can run in parallel on the GPU without loading any files from disk. For reverberation, we'll use the noise-shaped impulse response implementation from [`dasp-pytorch`](https://github.com/csteinmetz1/dasp-pytorch/).
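
The transforms in `wm_tutorial/tfm.py` aren't reproduced here. As a rough illustration of the idea, the sketch below adds white noise at a per-example target SNR and convolves with a synthetic impulse response made of exponentially decaying Gaussian noise, using only batched tensor ops and no file I/O. The helpers `add_noise_at_snr` and `synthetic_reverb` are hypothetical, the decay time is an arbitrary assumption, and this single-band variant is written for this example only; it is not the `dasp-pytorch` implementation, which applies band-wise decays.

```python
import torch

def add_noise_at_snr(audio: torch.Tensor, snr_db: torch.Tensor) -> torch.Tensor:
    """Add white Gaussian noise. audio: (n_batch, n_samples); snr_db: (n_batch,)."""
    noise = torch.randn_like(audio)
    sig_pow = audio.pow(2).mean(dim=-1, keepdim=True)
    noise_pow = noise.pow(2).mean(dim=-1, keepdim=True)
    # Per-example gain so that sig_pow / (gain**2 * noise_pow) = 10**(snr_db / 10)
    gain = torch.sqrt(sig_pow / (noise_pow * 10 ** (snr_db.unsqueeze(-1) / 10)))
    return audio + gain * noise

def synthetic_reverb(audio: torch.Tensor, sr: int, rt60_s: float = 0.3, ir_len_s: float = 0.5) -> torch.Tensor:
    """Convolve each example with exponentially decaying noise (-60 dB at rt60_s seconds)."""
    n_batch, n_samples = audio.shape
    n_ir = int(ir_len_s * sr)
    t = torch.arange(n_ir, device=audio.device) / sr
    envelope = 10 ** (-3 * t / rt60_s)      # amplitude 10**-3 (i.e. -60 dB) at t = rt60_s
    ir = torch.randn(n_batch, n_ir, device=audio.device) * envelope
    ir[:, 0] = 1.0                          # direct path
    ir = ir / ir.norm(dim=-1, keepdim=True)  # unit-energy impulse response
    # Linear (not circular) convolution via zero-padded FFTs, trimmed to the input length
    n_fft = n_samples + n_ir - 1
    wet = torch.fft.irfft(torch.fft.rfft(audio, n=n_fft) * torch.fft.rfft(ir, n=n_fft), n=n_fft)
    return wet[..., :n_samples]

torch.manual_seed(0)
sr = 16_000
clean = torch.randn(4, sr)                    # stand-in batch of 1 s signals
snr_db = torch.empty(4).uniform_(10.0, 30.0)  # per-example SNR in [10, 30] dB
noisy = add_noise_at_snr(clean, snr_db)
reverberant = synthetic_reverb(noisy, sr)

measured_snr = 10 * torch.log10(clean.pow(2).mean(-1) / (noisy - clean).pow(2).mean(-1))
print(reverberant.shape, torch.allclose(measured_snr, snr_db, atol=1e-3))
```

Because every step is a batched tensor operation, a simulation like this can run on the same device as the networks inside the training loop.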

In [None]:
# Initialize networks
embedder = Embedder(sample_rate, stft_params, n_bits, **config).to(device)
detector = Detector(sample_rate, stft_params, n_bits, **config).to(device)

print(f"Embedder parameters: {count_parameters(embedder)}")
print(f"Detector parameters: {count_parameters(detector)}")

opt_embedder = torch.optim.AdamW(embedder.parameters(), lr=3e-4)
opt_detector = torch.optim.AdamW(detector.parameters(), lr=3e-4)

# Load audio data
loader = AudioLoader(sources=[DATA_DIR/"LibriTTS_R/train-clean-100"])
dataset = AudioDataset(
    loader, 
    sample_rate=sample_rate,
    n_examples=1_000_000,       # Number of unique rows to load
    duration=max_len_s,         # Pad/trim all audio to this duration
    loudness_cutoff=-40,        # Sample random excerpts until minimum loudness (dB) cutoff is met
    num_channels=1,             # If 1, downmix all audio to mono
    without_replacement=False,  # Sample audio files with/without replacement
)

# Initialize data loader
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate,
    shuffle=True
)

# Training transforms
fast_noise = Noise(
    snr=("uniform", 10.0, 30.0),
    eq_amount=("const", 1.0),
    prob=0.5,
)
eq = Equalizer(
    n_bands=6,
    prob=0.5,
)
speed = Speed(
    factor=("choice", (0.95, 0.96, 0.97, 0.98, 0.99, 1.01, 1.02, 1.03, 1.04, 1.05)),
    prob=0.5,
)
train_tfm = Compose(speed, fast_noise, eq)
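`Speed` and `Noise` aren't defined in this section (presumably they live in `wm_tutorial/tfm.py` alongside the other simulations). As an assumption-laden sketch of one way to perturb speed on the GPU without file I/O, the hypothetical `speed_perturb` below resamples each example at its own rate using linear interpolation and keeps the original length by zero-padding or truncating. Note that this resampling approach shifts pitch along with tempo; the tutorial's `Speed` transform may work differently.

```python
import torch


def speed_perturb(x: torch.Tensor, factor: torch.Tensor) -> torch.Tensor:
    """Change playback speed of each example by resampling (hypothetical helper).

    x: (n_batch, n_channels, n_samples); factor: (n_batch,), where factor > 1
    plays faster. The output keeps n_samples: sped-up audio is zero-padded at
    the end, and slowed-down audio is truncated.
    """
    n_batch, n_channels, n_samples = x.shape
    # Output sample j reads the (fractional) input position j * factor
    pos = torch.arange(n_samples, device=x.device, dtype=x.dtype) * factor.view(-1, 1).to(x.dtype)
    valid = pos <= n_samples - 1                     # positions inside the input
    pos = pos.clamp(max=n_samples - 1)
    lo = pos.floor().long()                          # left neighbor index
    hi = (lo + 1).clamp(max=n_samples - 1)           # right neighbor index
    frac = (pos - lo.to(x.dtype)).unsqueeze(1)       # interpolation weight, (n_batch, 1, n_samples)
    lo = lo.unsqueeze(1).expand(-1, n_channels, -1)
    hi = hi.unsqueeze(1).expand(-1, n_channels, -1)
    # Linear interpolation between neighboring input samples
    y = (1 - frac) * x.gather(-1, lo) + frac * x.gather(-1, hi)
    return y * valid.unsqueeze(1).to(x.dtype)        # zeros past the end of the input


# Example: factors like those sampled by the `Speed` config above
audio = torch.randn(3, 1, 16000)
faster_or_slower = speed_perturb(audio, torch.tensor([0.95, 1.0, 1.05]))
```

Because every example has its own sampling grid, the whole batch is handled by two `gather` calls rather than a Python loop over examples.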

In [None]:
acc_d = []
loss_d = []
loss_pt = []

train_iter = iter(dataloader)

pbar = tqdm(range(n_train_steps))
for i in pbar:
    
    embedder.train()
    detector.train()
    
    opt_embedder.zero_grad()
    opt_detector.zero_grad()

    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(dataloader)
        batch = next(train_iter) 

    idx = batch["idx"].to(device)
    signal = batch["signal"].to(device)                     # (n_batch, n_channels, n_samples)
    signal_lengths = batch["signal_lengths"].to(device)     # (n_batch,)
    
    # Sample random watermark message
    msg = torch.randint(
        0, 2, 
        (signal.batch_size, n_bits), 
        device=device
    )  # (n_batch, n_bits)

    # Embed message
    watermarked = embedder(signal, msg, signal_lengths)  # (n_batch, n_channels, n_samples)

    # Apply transformations
    tfm_kwargs = train_tfm.batch_instantiate(idx.tolist(), watermarked)
    tfm_watermarked = train_tfm.transform(watermarked.clone(), **tfm_kwargs)
    
    # Detect message
    msg_pred = detector(tfm_watermarked, signal_lengths)  # (n_batch, n_bits)
    
    # Compute detection loss
    _loss_d = torch.nn.functional.binary_cross_entropy(
        msg_pred, msg.float(), reduction="none",
    ).mean(dim=-1)  # (n_batch,)
    
    # Compute perceptual transparency loss: L2 norm of the STFT magnitude
    # difference between watermarked and original audio
    _loss_pt = (
        watermarked.magnitude - signal.magnitude
    ).reshape(signal.batch_size, -1).norm(dim=-1)  # (n_batch,)

    # Combine losses. The perceptual transparency weight follows a simple
    # "curriculum": it starts at 1e-4 and ramps linearly to 1e-3 over the first
    # 900 steps, so the transparency loss doesn't dominate the detection loss,
    # which is trickier to optimize through simulated audio transformations
    _loss = _loss_d + min(0.001, 0.0001 + (0.000001 * i)) * _loss_pt  # (n_batch,)

    # Backward pass
    _loss.mean().backward()
    torch.nn.utils.clip_grad_norm_(embedder.parameters(), 1.0)
    torch.nn.utils.clip_grad_norm_(detector.parameters(), 1.0)

    # Update
    opt_embedder.step()
    opt_detector.step()

    # Logging
    with torch.no_grad():
        acc_d += [((msg_pred > 0.5) == msg).float().mean().item()]
        loss_d += [_loss_d.mean().item()]
        loss_pt += [_loss_pt.mean().item()]

    pbar.set_description(
        f"Detection accuracy: {acc_d[-1] :0.2f}, "
        f"Detection loss: {loss_d[-1] :0.2f}, "
        f"Perceptual transparency loss: {loss_pt[-1] :0.2f}"
    )
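The weight on the perceptual transparency loss in the loop above is `min(0.001, 0.0001 + 0.000001 * i)`. Pulling it out into a small function (the name `pt_weight` is just for illustration) makes the schedule explicit: the weight starts at 1e-4, grows by 1e-6 per step, and saturates at 1e-3 after (1e-3 - 1e-4) / 1e-6 = 900 steps.

```python
def pt_weight(step: int, start: float = 1e-4, rate: float = 1e-6, cap: float = 1e-3) -> float:
    """Weight on the perceptual transparency loss at a given training step:
    a linear ramp from `start`, growing by `rate` per step, capped at `cap`."""
    return min(cap, start + rate * step)


# Step at which the ramp reaches the cap: (cap - start) / rate
ramp_steps = round((1e-3 - 1e-4) / 1e-6)

for step in (0, 450, ramp_steps, 5_000):
    print(f"step {step:>5}: weight = {pt_weight(step):.1e}")
```

If `n_train_steps` is below 900, the cap is never reached and the transparency loss stays down-weighted for the entire run.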

In [None]:
plt.plot(acc_d)
plt.title("Detection accuracy (training)")
plt.show()

plt.plot(loss_d)
plt.title("Detection loss (training)")
plt.show()

plt.plot(loss_pt)
plt.title("Perceptual transparency loss (training)")
plt.show()

In [None]:
# Load audio
signal = AudioSignal(ASSETS_DIR / "audio" / "bryan_0.wav").resample(sample_rate).to(device)
signal = signal[..., :int(max_len_s * signal.sample_rate)]

msg = torch.randint(0, 2, (1, n_bits), device=device)
signal.clone().detach().cpu().widget()

# Switch to inference mode (the training loop above left both networks in training mode)
embedder.eval()
detector.eval()

with torch.no_grad():
    # Embed watermark
    watermarked = embedder(signal, msg)
    watermarked.clone().detach().cpu().widget()
    (signal - watermarked).clone().detach().cpu().widget()

    # Run detection on both original and watermarked signals
    msg_pred_unwatermarked = detector(signal) > 0.5
    msg_pred_watermarked = detector(watermarked) > 0.5

# Evaluate perceptual transparency
print(
    f"Watermark SNR: {snr(watermarked, signal).item() :0.2f}\n"
    f"Watermark SI-SDR: {si_sdr(watermarked, signal).item() :0.2f}\n"
)

# Plot detection results
plt.plot(msg.flatten().tolist(), label="actual watermark")
plt.plot(msg_pred_unwatermarked.flatten().tolist(), label="predicted (unwatermarked)")
plt.plot(msg_pred_watermarked.flatten().tolist(), label="predicted (watermarked)")
plt.xlabel("Watermark message bit")
plt.title(
    f"Actual vs. predicted watermark vector\n"
    f"(unwatermarked accuracy={(msg_pred_unwatermarked.flatten()==msg.flatten()).float().mean() :0.2f})\n"
    f"(watermarked accuracy={(msg_pred_watermarked.flatten()==msg.flatten()).float().mean() :0.2f})\n"
)
plt.legend()
plt.show()
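`snr` and `si_sdr` are imported from `wm_tutorial.util`, whose implementations aren't shown here. The sketch below follows the standard definitions they presumably implement; the argument order `(estimate, reference)`, the zero-mean step, and the epsilon handling are assumptions. Treating the watermark as "noise" added to the original, SNR is `10 * log10(||s||^2 / ||s_hat - s||^2)`, so a higher value means a quieter watermark. SI-SDR first rescales the reference by the optimal gain `alpha = <s_hat, s> / ||s||^2`, so a pure change in volume is not counted as distortion.

```python
import torch


def snr_db(estimate: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """10 * log10(||reference||^2 / ||estimate - reference||^2) over the last dim."""
    noise = estimate - reference
    return 10 * torch.log10((reference.pow(2).sum(-1) + eps) / (noise.pow(2).sum(-1) + eps))


def si_sdr_db(estimate: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SDR: project the estimate onto the reference before measuring error."""
    estimate = estimate - estimate.mean(-1, keepdim=True)     # zero-mean (common convention)
    reference = reference - reference.mean(-1, keepdim=True)
    # Optimal gain applied to the reference
    alpha = (estimate * reference).sum(-1, keepdim=True) / (reference.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * reference
    residual = estimate - target                               # everything a gain can't explain
    return 10 * torch.log10((target.pow(2).sum(-1) + eps) / (residual.pow(2).sum(-1) + eps))


# Example: a 10% louder copy vs. a copy with independent additive noise
s = torch.randn(1, 16000)
print(snr_db(1.1 * s, s), si_sdr_db(1.1 * s, s))
print(snr_db(s + 0.1 * torch.randn(1, 16000), s))
```

Under these definitions, scaling a signal by 1.1 yields an SNR of 20 dB but a very high SI-SDR, since the entire "error" is a gain change; for additive independent noise the two metrics roughly agree.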

Now that we've trained with transformations, let's see if we do any better on our robustness evaluation!

In [None]:
# Watermarked and unwatermarked scores
uwm_scores = []
wm_scores = []

pbar = tqdm(val_dataloader)
for batch in pbar:
    
    embedder.eval()
    detector.eval()

    idx = batch["idx"].to(device)
    signal = batch["signal"].to(device)                     # (n_batch, n_channels, n_samples)
    signal_lengths = batch["signal_lengths"].to(device)     # (n_batch,)
    
    # Sample random watermark message
    msg = torch.randint(
        0, 2, 
        (signal.batch_size, n_bits), 
        device=device
    )  # (n_batch, n_bits)

    with torch.no_grad():
    
        # Embed message
        watermarked = embedder(signal, msg, signal_lengths)  # (n_batch, n_channels, n_samples)

        # Apply randomized transformations
        tfm_kwargs = noise_and_reverb.batch_instantiate(idx.tolist(), signal)
        tfm_signal = noise_and_reverb.transform(signal.clone().cpu(), **tfm_kwargs).to(device)
        tfm_watermarked = noise_and_reverb.transform(watermarked.clone().cpu(), **tfm_kwargs).to(device)
        
        # Detect message
        msg_pred_wm = detector(tfm_watermarked)  # (n_batch, n_bits)
        msg_pred_uwm = detector(tfm_signal)      # (n_batch, n_bits)

        acc_wm = ((msg_pred_wm > 0.5) == msg).float().mean(dim=-1)    # (n_batch,)
        acc_uwm = ((msg_pred_uwm > 0.5) == msg).float().mean(dim=-1)  # (n_batch,)

        wm_scores += [acc_wm]
        uwm_scores += [acc_uwm]

wm_scores = torch.cat(wm_scores, dim=0)
uwm_scores = torch.cat(uwm_scores, dim=0)

print(
    f"TPR @ 10% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.10) :0.2f}\n"
    f"TPR @ 1% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.01) :0.2f}\n"
    f"TPR @ 0.1% FPR: {tpr_at_fpr(wm_scores, uwm_scores, 0.001) :0.2f}\n"
)
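`tpr_at_fpr` also comes from `wm_tutorial.util` and isn't shown. Roughly, a metric like this picks a detection threshold from the unwatermarked scores so that at most the target fraction of them exceed it, then reports the fraction of watermarked scores above that threshold. The hypothetical `tpr_at_fpr_sketch` below is one way to do this; the library version may interpolate an ROC curve or break ties differently.

```python
import torch


def tpr_at_fpr_sketch(pos_scores: torch.Tensor, neg_scores: torch.Tensor, fpr: float) -> float:
    """True positive rate at a threshold chosen so that at most `fpr` of the
    negative (unwatermarked) scores lie strictly above it."""
    neg_sorted = torch.sort(neg_scores.flatten(), descending=True).values
    # At most k negatives may be flagged; clamp so that fpr=1.0 stays in range
    k = min(int(fpr * neg_sorted.numel()), neg_sorted.numel() - 1)
    threshold = neg_sorted[k]
    # Strict ">" means scores tied with the threshold count as "not watermarked"
    return (pos_scores > threshold).float().mean().item()


# Example with bit-accuracy-style scores
neg = torch.tensor([0.4, 0.5, 0.5, 0.6, 0.5, 0.45, 0.55, 0.5, 0.5, 0.7])
pos = torch.tensor([0.9, 1.0, 0.65, 0.6, 0.55])
print(tpr_at_fpr_sketch(pos, neg, 0.1))
```

With N unwatermarked examples, FPRs below 1/N can't be resolved: the threshold simply becomes the highest unwatermarked score, so a "TPR @ 0.1% FPR" figure is only meaningful with at least 1,000 validation clips.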

## A Final Note on Tradeoffs

We were able to improve our robustness, but with some notable costs:
* Training took longer because every step now runs our simulated transformations; more efficient implementations can sometimes reduce this overhead, but training through transformations generally slows things down
* Our "robust" watermark was more clearly audible than our original watermark; in fact, we needed to use a "curriculum" to down-weight our perceptual transparency loss to allow our watermark to be "loud" enough to withstand audio transformations!


When training end-to-end watermarking systems, we need to account for the inherent tradeoffs between robustness, perceptual transparency, and capacity. For example, encoding longer watermark messages (i.e., _increasing capacity_) typically comes at a cost to transparency or robustness; try increasing `n_bits` in this notebook and see for yourself!
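Message length also limits what the bit-accuracy scores used in this notebook can achieve at low FPR. Because `msg` is drawn uniformly at random and independently of the unwatermarked audio, a detector whose output depends only on the audio matches each bit of an unwatermarked clip with probability exactly 1/2, independently across bits. The number of matching bits therefore follows Binomial(`n_bits`, 1/2). The sketch below (hypothetical helper names) finds the smallest number of matching bits that keeps the false positive rate at or below a target. With 8 bits, for instance, no threshold reaches 0.1% FPR, since even a perfect match happens by chance 1/256 of the time.

```python
from math import comb


def binomial_tail(n_bits: int, k: int) -> float:
    """P[X >= k] for X ~ Binomial(n_bits, 1/2): chance of >= k matching bits."""
    return sum(comb(n_bits, j) for j in range(k, n_bits + 1)) / 2 ** n_bits


def accuracy_threshold(n_bits: int, target_fpr: float):
    """Smallest k with P[X >= k] <= target_fpr, as (k, k / n_bits, actual FPR).
    Returns None when even a perfect match is too likely by chance."""
    for k in range(n_bits + 1):
        p = binomial_tail(n_bits, k)
        if p <= target_fpr:
            return k, k / n_bits, p
    return None


for n_bits in (8, 16, 32, 64):
    print(n_bits, accuracy_threshold(n_bits, 0.001))
```

Longer messages lower the bit accuracy needed to reach a given FPR, but each extra bit must still survive the embedding and the transformations, which is where the capacity tradeoff above comes in.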