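# High-resolution inference

Embed a watermark into full-resolution videos with a VideoSeal checkpoint, save the original, watermarked and amplified-difference videos, and report PSNR together with detection metrics (bit accuracy, p-value, capacity) after H264 compression and cropping.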

In [1]:
# run in the root of the repository
%load_ext autoreload
%autoreload 2
 
%cd ../..

/private/home/pfz/09-videoseal/videoseal-dev


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
from videoseal.utils.display import save_vid
from videoseal.utils import Timer
from videoseal.evals.full import setup_model_from_checkpoint
from videoseal.evals.metrics import bit_accuracy, pvalue, capacity, psnr, ssim, msssim, linf
from videoseal.data.datasets import VideoDataset
from videoseal.augmentation import Identity, H264, Crop
from videoseal.modules.jnd import JND, JNDSimplified

import os
from tqdm import tqdm
import torch
import gc

# Prefer the GPU when available, but fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing later at `wam.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import gc
import os

import torch
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Directory containing the input videos (only *.mp4 files are used)
video_dir = "/checkpoint/pfz/projects/videoseal/assets/videos/metamoviegen_3s"
base_output_folder = "outputs"
os.makedirs(base_output_folder, exist_ok=True)

# Checkpoints to evaluate: {name used in output filenames: checkpoint path}
ckpts = {
    "256b": "/checkpoint/pfz/2025_logs/0306_vseal_ydisc_release_bis/_nbits=256/checkpoint450.pth",
}

max_videos = 3   # number of videos processed per checkpoint
fps = 24         # frame rate used when writing the output mp4 files (assumed to match the source clips)
diff_gain = 10   # amplification of |imgs_w - imgs| so the watermark residual is visible in the saved video
aggregate = True  # True: detect_and_aggregate + p-value/capacity; False: per-frame detect, bit accuracy only

# Augmentations applied to the watermarked video before detection, as (H264 crf, crop size).
# None disables a step; when both are set, H264 is applied first and Crop second.
augmentations = [
    (None, None),  # no augmentation ("Original")
    (30, 0.75),
    (40, None),
    (50, None),
]


def augmentation_label(crf, crop):
    """Human-readable augmentation name, e.g. 'H264 30 + Crop 0.75', or 'Original' if none."""
    parts = []
    if crf is not None:
        parts.append(f"H264 {crf}")
    if crop is not None:
        parts.append(f"Crop {crop}")
    return " + ".join(parts) if parts else "Original"


def metric_suffix(label):
    """Turn a label into a metric-key suffix: lowercase, spaces -> '_', dots removed."""
    return label.lower().replace(" ", "_").replace(".", "")


def apply_augmentation(imgs, crf, crop):
    """Return `imgs` H264-compressed at `crf` and/or cropped to `crop` (H264 first)."""
    out = imgs
    if crf is not None:
        out, _ = H264()(out, crf=crf)
    if crop is not None:
        out, _ = Crop()(out, size=crop)
    return out


# a timer to measure the time
timer = Timer()

# ---------------------------------------------------------------------------
# Embed, save and detect for every checkpoint / video
# ---------------------------------------------------------------------------
for model_name, ckpt in ckpts.items():
    wam = setup_model_from_checkpoint(ckpt)
    wam.eval()
    wam.to(device)

    # Embedding strength. Alternatives tried before: JND attenuation, e.g.
    #   wam.attenuation = JND(in_channels=1, out_channels=1, blue=False) with scaling_w = 0.3, or
    #   wam.attenuation = JNDSimplified(in_channels=1, out_channels=3, blue=True) with scaling_w = 1.0
    wam.blender.scaling_w = 0.2

    # wam.chunk_size = 200
    wam.step_size = 4
    wam.video_mode = "repeat"

    # os.listdir order is arbitrary: sort so the same videos are picked on every run
    video_files = sorted(f for f in os.listdir(video_dir) if f.endswith(".mp4"))[:max_videos]

    for video_file in tqdm(video_files, desc=f"Processing Videos for {model_name}"):
        video_path = os.path.join(video_dir, video_file)
        base_name = os.path.splitext(video_file)[0]
        out_prefix = f"{base_output_folder}/{model_name}_{base_name}"

        timer.start()
        vid, mask = VideoDataset.load_full_video_decord(video_path)
        print(f"loading video {video_path} - took {timer.stop():.2f}s")

        # Watermark embedding
        timer.start()
        outputs = wam.embed_lowres_attenuation(vid, is_video=True)
        print(f"embedding watermark  - took {timer.stop():.2f}s")

        imgs = vid  # b c h w
        imgs_w = outputs["imgs_w"]  # b c h w
        msgs = outputs["msgs"]  # b k
        diff = imgs_w - imgs

        timer.start()
        save_vid(imgs, f"{out_prefix}_ori.mp4", fps)
        save_vid(imgs_w, f"{out_prefix}_wm.mp4", fps)
        save_vid(diff_gain * diff.abs(), f"{out_prefix}_diff.mp4", fps)
        print(f"saving videos - took {timer.stop():.2f}s")

        metrics = {
            "psnr": psnr(imgs, imgs_w, is_video=True).mean().item(),
        }

        print("compressing and detecting watermarks")
        timer.start()
        for crf, crop in augmentations:
            imgs_aug = apply_augmentation(imgs_w, crf, crop)
            suffix = metric_suffix(augmentation_label(crf, crop))
            if aggregate:
                # One aggregated prediction for the whole video, compared with the first
                # frame's message (assumes every frame carries the same message — TODO confirm).
                bit_preds = wam.detect_and_aggregate(imgs_aug)
                ref_msgs = msgs[:1]
                metrics[f"bit_accuracy_{suffix}"] = bit_accuracy(bit_preds, ref_msgs).nanmean().item()
                metrics[f"pvalue_{suffix}"] = pvalue(bit_preds, ref_msgs).nanmean().item()
                metrics[f"capacity_{suffix}"] = capacity(bit_preds, ref_msgs).nanmean().item()
            else:
                # Separate name so the embedding `outputs` above is not shadowed.
                det_outputs = wam.detect(imgs_aug, is_video=True)
                # drop channel 0 of preds (presumably the detection score — TODO confirm)
                bit_preds = det_outputs["preds"][:, 1:]  # b k ...
                metrics[f"bit_accuracy_{suffix}"] = bit_accuracy(bit_preds, msgs).nanmean().item()
        print(f"detection - took {timer.stop():.2f}s")
        print(metrics)

        # Free the large per-video tensors before the next video. `msgs`, `imgs_aug` and
        # `bit_preds` are intentionally kept: later cells inspect them.
        del vid, mask, outputs, imgs, imgs_w, diff
        gc.collect()
        torch.cuda.empty_cache()

Model loaded successfully from /checkpoint/pfz/2025_logs/0306_vseal_ydisc_release_bis/_nbits=256/checkpoint450.pth with message: <All keys matched successfully>


Processing Videos for 256b:   0%|          | 0/3 [00:00<?, ?it/s]

loading video /checkpoint/pfz/projects/videoseal/assets/videos/metamoviegen_3s/01.mp4 - took 2.27s
embedding watermark  - took 2.01s
compressing and detecting watermarks
{'psnr': 45.84193420410156, 'bit_accuracy_original': 1.0, 'pvalue_original': 8.636168555094445e-78, 'capacity_original': 256.0, 'bit_accuracy_h264_30_+_crop_08': 0.921875, 'pvalue_h264_30_+_crop_08': 2.643924297128365e-48, 'capacity_h264_30_+_crop_08': 154.74232482910156, 'bit_accuracy_h264_40': 0.9296875, 'pvalue_h264_40': 1.7638096167789926e-50, 'capacity_h264_40': 162.0252685546875, 'bit_accuracy_h264_50': 0.5625, 'pvalue_h264_50': 0.02623583666607148, 'capacity_h264_50': 2.8929443359375}


Processing Videos for 256b:  33%|███▎      | 1/3 [00:52<01:44, 52.41s/it]

loading video /checkpoint/pfz/projects/videoseal/assets/videos/metamoviegen_3s/02.mp4 - took 2.82s
embedding watermark  - took 1.99s
compressing and detecting watermarks
{'psnr': 44.39072036743164, 'bit_accuracy_original': 1.0, 'pvalue_original': 8.636168555094445e-78, 'capacity_original': 256.0, 'bit_accuracy_h264_30_+_crop_08': 0.86328125, 'pvalue_h264_30_+_crop_08': 1.684380861872333e-34, 'capacity_h264_30_+_crop_08': 108.6513671875, 'bit_accuracy_h264_40': 0.9453125, 'pvalue_h264_40': 3.799763808759645e-55, 'capacity_h264_40': 177.66787719726562, 'bit_accuracy_h264_50': 0.6640625, 'pvalue_h264_50': 8.338613937483142e-08, 'capacity_h264_50': 20.255218505859375}


Processing Videos for 256b:  67%|██████▋   | 2/3 [01:40<00:49, 49.74s/it]

loading video /checkpoint/pfz/projects/videoseal/assets/videos/metamoviegen_3s/03.mp4 - took 2.90s
embedding watermark  - took 2.09s
compressing and detecting watermarks
{'psnr': 45.5481071472168, 'bit_accuracy_original': 1.0, 'pvalue_original': 8.636168555094445e-78, 'capacity_original': 256.0, 'bit_accuracy_h264_30_+_crop_08': 0.96484375, 'pvalue_h264_30_+_crop_08': 1.0114433457778881e-61, 'capacity_h264_30_+_crop_08': 199.77603149414062, 'bit_accuracy_h264_40': 0.9140625, 'pvalue_h264_40': 3.205950926783553e-46, 'capacity_h264_40': 147.77284240722656, 'bit_accuracy_h264_50': 0.60546875, 'pvalue_h264_50': 0.0004446015373390903, 'capacity_h264_50': 8.278640747070312}


Processing Videos for 256b: 100%|██████████| 3/3 [02:27<00:00, 49.22s/it]


In [None]:
# Per-frame (non-aggregated) detection on `imgs_aug`, which is left over from the last
# augmentation of the last video in the evaluation loop above — run that cell first.
# NOTE(review): this displays the full preds tensor; consider `.shape` for large videos.
outputs = wam.detect(imgs_aug, is_video=True)
outputs["preds"]

In [None]:
# Detect the watermark in an externally produced video (a baseline output that was
# watermarked, cropped and H264-compressed) using the last checkpoint from the loop above.
wam = setup_model_from_checkpoint(ckpt)
wam.eval()
wam.to(device)
# NOTE(review): a fresh model does not carry the loop's `step_size` / `video_mode`
# settings; re-apply them here if detection depends on them — TODO confirm.

video_path = "/private/home/pfz/09-videoseal/baselines/outputs/videoseal0.4_01_wm_crop_08_h264_40.mp4"
vid, mask = VideoDataset.load_full_video_decord(video_path)

# Reference message: `msgs` is left over from the evaluation loop (last video processed).
# NOTE(review): only meaningful if the baseline embedded this same message — TODO load the
# baseline's ground-truth message instead.
ref_msgs = msgs[:1]

timer.start()
# Run detection on the video just loaded (previously this used the loop's leftover `imgs_aug`).
bit_preds = wam.detect_and_aggregate(vid)
bit_accuracy_ = bit_accuracy(bit_preds, ref_msgs).nanmean().item()
pvalue_ = pvalue(bit_preds, ref_msgs).nanmean().item()
capacity_ = capacity(bit_preds, ref_msgs).nanmean().item()
# Label the result with the file name (`crf` was never defined and raised NameError).
print(f"{os.path.basename(video_path)} - Bit Accuracy: {bit_accuracy_:.3f} - P-Value: {pvalue_:0.2e} - Capacity: {capacity_:.3f} - took {timer.stop():.2f}s")

In [None]:
# Messages embedded in the last video of the evaluation loop (shape b k per the loop's comment).
msgs

In [None]:
# Bit predictions from the most recently executed detection cell.
bit_preds