# Watermarking custom model
This notebook shows how to use the Omniseal Bench Python APIs to add a new model and run the evaluation tasks on different datasets. We demonstrate with audio, but the same approach applies to all modalities.

## Basic Concepts:
### Task:

The Task defines the dataset to be used for the evaluation, how to load the data (batch size, padding, etc.) and which metrics to run. The task type can be one of three:

- **generation**: Run a generator over a specific dataset to watermark each of its items, and return the quality metrics of the watermarked audios.

- **detection**: Apply a specific set of attacks over a watermarked dataset, then run the detector and report the detection scores (robustness) as well as the quality metrics for each attack.

- **default**: Run end-to-end generation and detection.

### Model:
A model is a wrapper over the watermarking model. A model in Omniseal Bench can be one of 3 kinds: Generator, Detector, or both. A Generator is any model that implements the `generate_watermark()` method, and a Detector is any model that implements the `detect_watermark()` method, each with a specific signature as shown below.


### Task - Model compatibility:

- A **generation** task can only evaluate a Generator (model that implements `generate_watermark()` method).

- A **detection** task can only evaluate a Detector (model that implements `detect_watermark()` method).

- A **default** task can only evaluate a model that implements both `generate_watermark()` and `detect_watermark()`
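
The compatibility rules above can be sketched as a small check. This is only an illustration, not Omniseal Bench's own code: the `Generator` and `Detector` protocols and the `compatible_tasks()` helper are hypothetical names introduced here, and the check only looks at whether the two methods exist, not at their signatures.

```python
from typing import Any, List, Protocol, runtime_checkable


@runtime_checkable
class Generator(Protocol):
    # Hypothetical protocol: anything exposing generate_watermark()
    def generate_watermark(self, contents: Any, message: Any) -> Any: ...


@runtime_checkable
class Detector(Protocol):
    # Hypothetical protocol: anything exposing detect_watermark()
    def detect_watermark(self, contents: Any) -> Any: ...


def compatible_tasks(model: object) -> List[str]:
    """Return the task types that `model` could be evaluated on (hypothetical helper)."""
    tasks = []
    is_generator = isinstance(model, Generator)
    is_detector = isinstance(model, Detector)
    if is_generator:
        tasks.append("generation")
    if is_detector:
        tasks.append("detection")
    if is_generator and is_detector:
        tasks.append("default")
    return tasks


class OnlyDetector:
    def detect_watermark(self, contents):
        return {"prediction": [1.0] * len(contents)}


print(compatible_tasks(OnlyDetector()))  # ['detection']
```

A class that implements both methods would be reported as compatible with all three task types.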

In [None]:
# Some Jupyter kernels check TORCH_DISTRIBUTED_DEBUG and throw an error if not found
import os
import sys

os.environ["TORCH_DISTRIBUTED_DEBUG"] = "OFF"
sys.executable

## Running example:

The steps to build a custom model:

#### Step 1: Define the model class(es)
A model class in Omniseal Bench can be one of 3 types: Generator, Detector, or both. A Generator is any class that implements the `generate_watermark()` method, and a Detector is any class that implements the `detect_watermark()` method, each with a specific signature as shown below.

#### Step 2: Define the builder function:
  
  - The builder function is a function that returns an object of the model class. The function name must be exactly one of `build_generator()` (returns a Generator), `build_detector()` (returns a Detector), or `build_model()` (returns an object that is both a Generator and a Detector).

  - The function can have any parameters, but should contain at least one parameter `device`, which defines the device the model object will be placed on.
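
The naming convention above can be illustrated with a short sketch of how a loader might locate a builder function in a Python file. This is an assumption for illustration, not Omniseal Bench's actual implementation: `find_builder()`, `BUILDER_NAMES` and the search order are hypothetical, and only the standard-library `importlib.util` calls are real APIs.

```python
import importlib.util
import tempfile
from pathlib import Path

# Builder names from the convention above; the search order is an assumption
BUILDER_NAMES = ("build_model", "build_generator", "build_detector")


def find_builder(py_file: str):
    """Load a Python file as a module and return (name, function) of the first builder found."""
    spec = importlib.util.spec_from_file_location("custom_model_module", py_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    for name in BUILDER_NAMES:
        fn = getattr(module, name, None)
        if callable(fn):
            return name, fn
    raise ValueError(f"No builder function found in {py_file}; expected one of {BUILDER_NAMES}")


# Usage with a throwaway module file that defines only build_detector()
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_module.py"
    path.write_text("def build_detector(device='cpu'):\n    return {'device': device}\n")
    name, builder = find_builder(str(path))
    built = builder(device="cpu")

print(name, built)  # build_detector {'device': 'cpu'}
```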

In [None]:

# Put the following code in a module file (toy_audio.py)

from typing import Dict, List, Optional
import torch
import torch.nn as nn

# STEP 1: Define the model
# The compatibility between model and the task:
# For the model to be evaluated in a generation task ("generation"), we must implement generate_watermark()
# For the model to be evaluated in a detection task ("detection"), we must implement detect_watermark()
# For the model to be evaluated in an end-to-end task ("default"), we must implement both functions
class ToyAudioWatermark:
    
    model: nn.Module
    
    def __init__(self, model: torch.nn.Module, alpha: float = 0.0001, M: float = 1.0, nbits: int = 16):
        self.model = model
        
        # A wrapper can have any additional arguments. These arguments should be set in the builder func
        self.alpha = alpha
        self.M = M
        
        # Each model should have an attribute 'nbits'. If the model does not have this attribute,
        # we must set the value `message_size` in the task. If Omniseal could not find information 
        # from either model or the task, it will raise the ValueError
        self.nbits = nbits

    @torch.inference_mode()
    def generate_watermark(
        self,
        contents: torch.Tensor,
        message: torch.Tensor,
    ) -> torch.Tensor:
        # A generate_watermark() must have a specific signature:
        # Args:
        #  - contents: a torch.Tensor (with batch dimension at dim=0) or a list of torch.Tensor (each without batch dimension)
        #  - message: a torch.Tensor or a numpy array
        # Return:
        #  - Should have the same data type as 'contents' with the same dimension

        # A dummy implementation of watermark for demo. Here we just use a constant value M
        hidden = torch.full_like(contents, self.M, dtype=contents.dtype, device=contents.device)
        return contents + self.alpha * hidden
    
    @torch.inference_mode()
    def detect_watermark(
        self,
        contents: torch.Tensor,
        detection_threshold: float = 0.0,
        message_threshold: float = 0.0,
        detection_bits: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        # A detect_watermark() must have a specific signature:
        # Args:
        #  - contents: a torch.Tensor (with batch dimension at dim=0) or a list of torch.Tensor (each without batch dimension)
        #  - message_threshold: threshold used to convert the watermark output (probability
        #    of each bits being 0 or 1) into the binary n-bit message.
        #  - detection_threshold: threshold to convert the softmax output to binary indicating
        #    the probability of the content being watermarked
        #  - detection_bits: number of bits reserved for calculating detection accuracy. 
        # Returns:
        #  - a dictionary with some keys such as:
        #    - "prediction": The prediction probability of the content being watermarked or not. This should be a 1-dim tensor of size `B` for a batch size of `B`.
        #    - "message": The secret message of dimension `B x nbits`
        #    - "detection_bits": The list of bits reserved for calculating the detection accuracy.
        #   
        #    One of "prediction" and "detection_bits" must be provided. "message" is optional
        #    If "message" is returned, Omniseal Bench will compute message accuracy scores: "bit_acc", "word_acc", "p_value", "capacity", and "log10_p_value"
        #    Otherwise, these metrics will be skipped
        
        B = len(contents)

        # Dummy implementation
        if self.alpha == 0:
            return {
                "prediction": torch.zeros((B,), device=contents.device),  # prediction probability is a 1-dim vector
                "message": torch.zeros((B, self.nbits), device=contents.device)  # Dummy message for demonstration
            }
        return {
            "prediction": torch.ones((B,), device=contents.device),  # prediction probability is a 1-dim vector
            "message": torch.ones((B, self.nbits), dtype=contents.dtype, device=contents.device)  # Dummy message for demonstration
        }
    
    
    
# STEP 2: Define the builder function.
# 
# The function can have any parameters, but should contain at least one parameter "device" which defines the device the model object will be placed on.
# It is advisable to have the parameters of the model class __init__() match the arguments of this function.
def build_model(alpha: float = 0.0001, M: float = 1.0, nbits: int = 16, device: str = "cpu") -> ToyAudioWatermark:
    
    # No actual model, just a placeholder. Put it in eval mode and move it to the target device.
    model = torch.nn.Identity().eval().to(device)
    
    # ToyAudioWatermark is a plain wrapper (not an nn.Module), so .eval() and .to() are applied to the inner model
    return ToyAudioWatermark(model=model, alpha=alpha, M=M, nbits=nbits)
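
To give some intuition for the message metrics named in the comments above, here is a minimal sketch of how "bit_acc" and "word_acc" could be computed for one item. The definitions are assumptions for illustration (fraction of matching bits, and exact match of the whole message); Omniseal Bench's own implementation may differ, and `bit_accuracy()` and `word_accuracy()` are hypothetical helpers.

```python
from typing import Sequence


def bit_accuracy(decoded: Sequence[int], original: Sequence[int]) -> float:
    """Fraction of positions where the decoded bit equals the original bit (assumed definition)."""
    if len(decoded) != len(original) or len(original) == 0:
        raise ValueError("messages must be non-empty and have the same number of bits")
    matches = sum(int(d == o) for d, o in zip(decoded, original))
    return matches / len(original)


def word_accuracy(decoded: Sequence[int], original: Sequence[int]) -> float:
    """1.0 if every bit matches, else 0.0 (assumed definition)."""
    return 1.0 if list(decoded) == list(original) else 0.0


original = [1, 0, 1, 1, 0, 0, 1, 0]  # secret message embedded by the generator
decoded = [1] * 8  # the toy detector returns all ones whenever alpha != 0

print(bit_accuracy(decoded, original))   # 0.5
print(word_accuracy(decoded, original))  # 0.0
```

Under these assumed definitions, a detector that always returns all ones only matches the bits that happen to be 1 in the original message, so a bit accuracy near chance level would be expected for the toy model.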

Finally, as step 3, we load the custom model and run an evaluation task on it. We demonstrate with audio, but the same steps apply to all modalities.

In [None]:
mod = load_module_from_file("my_module", "toy_audio.py")


In [None]:
import importlib

mod1 = importlib.import_module("my_module")

In [None]:
# Example task that evaluates a custom model

from omnisealbench import task, get_model


custom_task = task(
    "default",
    modality="audio",
    
    # Dataset options
    dataset_type="local",  # (Only supported in Audio) type: 'hf', 'local'
    dataset_dir="../examples/",
    audio_pattern="*.wav",
    # if dataset_type = 'hf'
    # dataset_name="hf-internal-testing/librispeech_asr_dummy",
    # dataset_hf_subset="clean",
    # dataset_split="validation[:4]",

    # Data loading options:
    sample_rate=24_000,
    padding_strategy="longest",  # Supported padding: 'fixed', 'longest'
    # If padding_strategy='fixed', we need to specify max_length to pad or truncate all audios to a fixed length
    # max_length=24000,  # 1 second of audio at 24kHz 
    batch_size=2,
    num_workers=2,
    attacks="all",
    
    # Output options:
    # result_dir="toy_result",  # Directory to store final output
    # cache_dir="to_wm",  # Directory where the intermediate watermarked contents and secret messages are stored
)

# If passing the path of a Python file, Omniseal Bench will search for build_model(), build_generator() or build_detector()
# in the file and call that function
model = get_model("models/toy_audio.py", alpha=0.5, nbits=100, device="cuda")

avg_metrics, raw_results = custom_task(model)





Running AudioWatermarkAttacksAndDetection with attack: no-attack
Running AudioWatermarkAttacksAndDetection with attack: updownresample
Running AudioWatermarkAttacksAndDetection with attack: bandpass_filter
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_0.5
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_0.6
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_0.7
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_0.8
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_0.9
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_1.0
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_1.1
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_1.2
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_1.25
Running AudioWatermarkAttacksAndDetection with attack: speed__speed_factor_1.3
Run

In [2]:
avg_metrics

{'watermark_det_score': AverageMetric(avg=0.375, count=32, square=0.140625, avg_ci_fn=None),
 'watermark_det': AverageMetric(avg=0.0, count=32, square=0.0, avg_ci_fn=None),
 'fake_det_score': AverageMetric(avg=0.375, count=32, square=0.140625, avg_ci_fn=None),
 'fake_det': AverageMetric(avg=0.0, count=32, square=0.0, avg_ci_fn=None),
 'bit_acc': AverageMetric(avg=0.421875, count=32, square=0.20703125, avg_ci_fn=None),
 'word_acc': AverageMetric(avg=0.0, count=32, square=0.0, avg_ci_fn=None),
 'p_value': AverageMetric(avg=0.7141342163085938, count=32, square=0.6318017898593098, avg_ci_fn=None),
 'capacity': AverageMetric(avg=1.6845126152038574, count=32, square=4.185459876364575, avg_ci_fn=None),
 'log10_p_value': AverageMetric(avg=-0.29078994494562627, count=32, square=0.3054619264695369, avg_ci_fn=None),
 'snr': AverageMetric(avg=58.6848121881485, count=32, square=3443.9447498994255, avg_ci_fn=None),
 'decoder_time': AverageMetric(avg=2.065990625, count=32, square=4.271666096562499, a