<a href="https://colab.research.google.com/github/mayibongwemoyo/dawm/blob/main/examples/comprehensive_voxpopuli_Incremental_watermarking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture

import sys
!pip install torchaudio soundfile matplotlib scipy datasets pandas seaborn

In [2]:
%%capture
!pip install audioseal


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import sys
sys.path.append('/content/drive/MyDrive/dawm/examples')
import noteb00k
from datasets import load_dataset
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torchaudio
from audioseal import AudioSeal
import numpy as np

In [5]:
# Load a subset of VoxPopuli (example: 50 audios)
dataset = load_dataset("facebook/voxpopuli", "en", split="validation", streaming=True)
audios_to_test = []
for i, example in enumerate(dataset):
    if i >= 50:
        break
    audio_array = example["audio"]["array"]
    sr = example["audio"]["sampling_rate"]
    audios_to_test.append((torch.tensor(audio_array).float(), sr))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [6]:
def process_audio(audio, sr):
    # Ensure audio is a PyTorch tensor
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio).float()

    # Convert to 3D: (batch=1, channels=1, time)
    if audio.dim() == 1:  # (time) → (1, 1, time)
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif audio.dim() == 2:  # (channels, time) → (1, channels, time)
        audio = audio.unsqueeze(0)
    elif audio.dim() == 3:  # Already (batch, channels, time)
        pass
    else:
        raise ValueError(f"Unsupported audio shape: {audio.shape}")

    # Ensure mono audio by averaging channels
    if audio.shape[1] > 1:
        audio = audio.mean(dim=1, keepdim=True)

    # Resample to 16kHz if needed
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        audio = resampler(audio)
        sr = 16000

    # Generate messages (4 real, 4 fake)
    real_messages = [torch.randint(0, 2, (1, 16)) for _ in range(4)]
    fake_messages = [torch.randint(0, 2, (1, 16)) for _ in range(4)]

    # Embed ONLY real watermarks
    watermarked_audio = audio.clone()
    for i in range(len(real_messages)):
        generator = AudioSeal.load_generator("audioseal_wm_16bits")
        print(f"Shape before generator {i+1}: {watermarked_audio.shape}")
        watermarked_audio = generator(
            watermarked_audio,
            sample_rate=sr,
            alpha=0.5/(i+1)
        )
        print(f"Shape after generator {i+1}: {watermarked_audio.shape}")

    # Test detection - FIXED LOOP HERE
    results = []
    for msg_idx, msg in enumerate(real_messages + fake_messages):  # Changed to enumerate
        detector = AudioSeal.load_detector("audioseal_detector_16bits")
        detector.message = msg
        prob, detected_msg = detector.detect_watermark(watermarked_audio, sample_rate=sr)

        prob_value = prob if isinstance(prob, float) else prob.item()
        ber = (msg != detected_msg.round()).float().mean().item()

        results.append({
            "msg_type": "real" if msg_idx < len(real_messages) else "fake",  # Use msg_idx
            "ber": ber,
            "prob": prob_value
        })

    # Calculate metrics
    noise = watermarked_audio - audio
    snr = 10 * torch.log10(audio.pow(2).mean() / noise.pow(2).mean()).item()

    # Calculate false positives
    null_probs = []
    for _ in range(1000):
        random_msg = torch.randint(0, 2, (1, 16))
        detector = AudioSeal.load_detector("audioseal_detector_16bits")
        detector.message = random_msg
        prob, _ = detector.detect_watermark(watermarked_audio, sample_rate=sr)
        prob_value = prob if isinstance(prob, float) else prob.item()
        null_probs.append(prob_value)

    false_positive_rate = sum(p >= 0.99 for p in null_probs) / len(null_probs)

    return {
        "results": results,
        "snr": snr,
        "false_positive_rate": false_positive_rate
    }

In [None]:
all_results = []
for audio, sr in audios_to_test:
    metrics = process_audio(audio, sr)
    all_results.append(metrics)

Shape before generator 1: torch.Size([1, 1, 245439])
Shape after generator 1: torch.Size([1, 1, 245439])
Shape before generator 2: torch.Size([1, 1, 245439])
Shape after generator 2: torch.Size([1, 1, 245439])
Shape before generator 3: torch.Size([1, 1, 245439])
Shape after generator 3: torch.Size([1, 1, 245439])
Shape before generator 4: torch.Size([1, 1, 245439])
Shape after generator 4: torch.Size([1, 1, 245439])
Shape before generator 1: torch.Size([1, 1, 129279])
Shape after generator 1: torch.Size([1, 1, 129279])
Shape before generator 2: torch.Size([1, 1, 129279])
Shape after generator 2: torch.Size([1, 1, 129279])
Shape before generator 3: torch.Size([1, 1, 129279])
Shape after generator 3: torch.Size([1, 1, 129279])
Shape before generator 4: torch.Size([1, 1, 129279])
Shape after generator 4: torch.Size([1, 1, 129279])
Shape before generator 1: torch.Size([1, 1, 189760])
Shape after generator 1: torch.Size([1, 1, 189760])
Shape before generator 2: torch.Size([1, 1, 189760])
Sh

In [None]:
# Combine results
results_df = pd.DataFrame([item for res in all_results for item in res["results"]])

# Calculate statistics
print("Average SNR:", np.mean([res["snr"] for res in all_results]))
print("False Positive Rate:", np.mean([res["false_positive_rate"] for res in all_results]))
print("\nDetection Performance:")
print(results_df.groupby("msg_type").mean())

In [None]:
# all_results = []
# for audio, sr in audios_to_test:
#     # Ensure audio is tensor and enforce shape
#     if isinstance(audio, np.ndarray):
#         audio = torch.from_numpy(audio).float()
#     audio = audio.unsqueeze(0).unsqueeze(0) if audio.dim() == 1 else audio
#     print(f"Processing audio with shape: {audio.shape}")  # Debug line

#     metrics = process_audio(audio, sr)
#     all_results.append(metrics)

In [None]:
# Combine SNR data
snr_df = pd.DataFrame([r["snrs"] for r in all_results], columns=[f"WM_{i+1}" for i in range(4)])
snr_summary = snr_df.mean().reset_index(name="SNR (dB)")

# Combine BER and detection probability
results_df = pd.concat([r["results"] for r in all_results])
ber_summary = results_df.groupby(["msg_type", "watermark_count"]).ber.mean().reset_index()
prob_summary = results_df.groupby(["msg_type", "watermark_count"]).prob.mean().reset_index()

# False positive rates
fp_rates = [r["false_positive_rate"] for r in all_results]

In [None]:
plt.figure(figsize=(10, 5))
sns.lineplot(data=snr_summary, x="index", y="SNR (dB)", marker="o")
plt.title("Average SNR vs. Number of Watermarks")
plt.xlabel("Watermark Count")
plt.xticks(ticks=range(4), labels=["1", "2", "3", "4"])
plt.grid(True)
plt.show()

In [None]:
plt.figure(figsize=(10, 5))
sns.lineplot(data=ber_summary, x="watermark_count", y="ber", hue="msg_type", marker="o")
plt.title("Average BER vs. Watermark Count")
plt.xlabel("Watermark Count")
plt.ylabel("Bit Error Rate (BER)")
plt.grid(True)
plt.show()

In [None]:
plt.figure(figsize=(10, 5))
sns.histplot(fp_rates, bins=20, kde=True)
plt.title("Distribution of False Positive Rates Across Audios")
plt.xlabel("False Positive Rate")
plt.xlim(0, 1)
plt.show()

In [None]:
plt.figure(figsize=(10, 5))
sns.kdeplot(real_probs, label="Real Watermarks", fill=True)
sns.kdeplot(fake_probs, label="Fake Watermarks", fill=True)
plt.axvline(0.5, color="red", linestyle="--", label="Random Guess Threshold")
plt.title("Detection Probability Distribution (Real vs. Fake)")
plt.xlabel("Probability")
plt.legend()
plt.show()

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (20,3)

from noteb00k import play_audio, plot_waveform_and_specgram

plot_waveform_and_specgram(audio, sr, title="Original audio")

In [None]:
play_audio(audio, sr)

In [None]:
# Initialise Audioseal models
from audioseal import AudioSeal

# Load models
generator = AudioSeal.load_generator("audioseal_wm_16bits")
detector = AudioSeal.load_detector(("audioseal_detector_16bits"))


In [None]:
# Generate 3 real messages and 1 fake
real_messages = [torch.randint(0, 2, (1, 16)) for _ in range(3)]
fake_messages = [torch.randint(0, 2, (1, 16)) for _ in range(3)]

# Combine messages
all_messages = real_messages + fake_messages

for i, msg in enumerate(all_messages):
    print(f"Message {i+1}: {msg.numpy().flatten()}")  # Verify uniqueness

audios = audio.unsqueeze(0)  # Add batch dimension (if missing)
# audio = audio.unsqueeze(1)  # Add channel dimension (if mono)


watermarked_audio = audios.clone()  # Start with original audio

# Embed watermarks sequentially
for idx, msg in enumerate(real_messages):
    # Generate watermark for current message
    watermark = generator.get_watermark(watermarked_audio, sample_rate=sr)

    # Apply watermark with scaled strength
    # watermarked_audio = audios + watermark
    watermarked_audio = generator(watermarked_audio, sample_rate=sr, alpha=0.5/(idx+1))


    # Calculate metrics after EACH embedding
    noise = watermarked_audio - audios
    snr = 10 * torch.log10(audios.pow(2).mean() / noise.pow(2).mean())

    print(f"\nAfter Watermark {idx+1}:")
    print("-" * 50)
    print(f"  SNR: {snr:.2f} dB")

print("\nFinal Detection Results:")
# Detect ALL previous watermarks
for detect_idx, msg in enumerate(all_messages):
    # Create NEW detector for each test
    temp_detector = AudioSeal.load_detector("audioseal_detector_16bits")
    temp_detector.message = msg

    prob, detected_msg = temp_detector.detect_watermark(watermarked_audio,sample_rate=sr, message_threshold=0.5)
    ber = (msg != detected_msg.round()).float().mean()

    status = "REAL" if detect_idx < len(real_messages) else "FAKE"
    print(f"  Watermark {detect_idx+1}: BER={ber:.2f}, Prob={prob:.2f}")


In [None]:
# Calculate empirical p-value for fake detection
null_probs = []
for _ in range(1000):  # Increase for tighter confidence intervals
    # 1. Generate random message
    random_msg = torch.randint(0, 2, (1, 16))

    # 2. Fresh detector instance
    temp_detector = AudioSeal.load_detector("audioseal_detector_16bits")
    temp_detector.message = random_msg

    # 3. Proper output handling
    detection_result = temp_detector.detect_watermark(watermarked_audio, sample_rate=sr)
    if isinstance(detection_result, tuple):
        prob = detection_result[0]  # Get probability from tuple
    else:
        prob = detection_result  # Single value

    null_probs.append(float(prob))  # Explicit conversion

# Calculate p-value with continuity correction
extreme_count = sum(p >= 0.99 for p in null_probs)  # Count near-perfect detections
p_value = (extreme_count + 1) / (len(null_probs) + 1)
print(f"False positive rate: {100*extreme_count/len(null_probs):.2f}%")
print(f"p-value: {p_value:.6f} (n={len(null_probs)})")

In [None]:
# Calculate SNR
noise = watermarked_audio - audios
snr = 10 * torch.log10(audios.pow(2).mean() / noise.pow(2).mean())
print(f"Final SNR: {snr:.2f} dB")

In [None]:
plot_waveform_and_specgram(watermarked_audio.squeeze(), sr, title="Multi-Watermarked Audio")
play_audio(watermarked_audio, sr)