# High Resolution inference 

In [2]:
# run in the root of the repository
%load_ext autoreload
%autoreload 2
 
%cd ../..

/private/home/pfz/09-videoseal/videoseal-dev


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [3]:
from videoseal.utils.display import save_img
from videoseal.utils import Timer
from videoseal.evals.full import setup_model_from_checkpoint
from videoseal.evals.metrics import bit_accuracy, psnr, ssim
from videoseal.augmentation import Identity, JPEG
from videoseal.modules.jnd import JND, VarianceBasedJND

import os
import omegaconf
from tqdm import tqdm
import gc
from PIL import Image

import torch
import torchvision

to_tensor = torchvision.transforms.ToTensor()
to_pil = torchvision.transforms.ToPILImage()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu" 

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Directory containing videos
num_imgs = 10
assets_dir = "/checkpoint/pfz/projects/videoseal/assets/imgs"
assets_dir = "/large_experiments/omniseal/sa-1b/val"
# assets_dir = "/private/home/pfz/_images"
base_output_dir = "outputs"
# base_output_dir = "/checkpoint/pfz/2025_logs/0206_vseal_rgb_y_images_for_s/att"
os.makedirs(base_output_dir, exist_ok=True)

# Checkpoint
ckpts = {
    # "videoseal0.1": '/private/home/hadyelsahar/work/code/videoseal/2024_logs/videoseal0.1/_lambda_d=0.5_lambda_i=0.5_optimizer=AdamW,lr=1e-4_videowam_step_size=4_video_start=500_embedder_model=unet_small2/checkpoint.pth',
    # "videoseal0.2b": "/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1111-videoseal0.2-archsearch-4nodes/_attenuation=None_nbits=64_finetune_detector_start=800_embedder_model=unet_small2_quant/checkpoint.pth",
    # "videoseal0.2a": "/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1109-videoseal0.2-discloss-fix-hing-sleepwake-4nodes/_scaling_w=0.5_lambda_i=0.5_disc_hinge_on_logits_fake=True_sleepwake=False_video_start=500/checkpoint.pth",
    # "videoseal0.4": "/large_experiments/meres/hadyelsahar/2024_logs/1120-videoseal0.4/_scaling_w=0.5_sleepwake=False_videowam_step_size=4_extractor_model=sam_tiny/checkpoint.pth",
    # "trustmark": "baseline/trustmark",
    # "wam": "baseline/wam",
    # "cin": "baseline/cin",
    # "mbrs": "baseline/mbrs",
    # "rgb": "/checkpoint/pfz/2025_logs/0206_vseal_rgb_y_64bits_lessdisc/_lambda_d=0.1_optimizer=AdamW,lr=1e-4_embedder_model=1/checkpoint.pth",
    # "y": "/checkpoint/pfz/2025_logs/0207_vseal_y_64bits_scalingw_schedule/_scaling_w_schedule=0_scaling_w=0.1/checkpoint650.pth",
    "96b_y": "/checkpoint/pfz/2025_logs/0219_vseal_convnextextractor/_nbits=96_lambda_i=0.1_embedder_model=1/checkpoint600.pth",
    "96b_y_400": "/checkpoint/pfz/2025_logs/0219_vseal_convnextextractor/_nbits=96_lambda_i=0.1_embedder_model=1/checkpoint400.pth",
}

for ckpt_name, ckpt_path in ckpts.items():

    output_dir = os.path.join(base_output_dir, ckpt_name)
    os.makedirs(output_dir, exist_ok=True)

    # a timer to measure the time
    timer = Timer()

    # Iterate over all checkpoints
    wam = setup_model_from_checkpoint(ckpt_path)
    wam.eval()
    wam.to(device)

    # attenuation = VarianceBasedJND(
    #     mode="variance",
    #     max_variance_value_for_clipping=300,
    #     min_heatmap_value=0.1,
    #     avg_pool_kernel_size=3
    # )
    # attenuation = JND(
    #     in_channels=1,
    #     out_channels=1,
    # )
    # wam.attenuation = attenuation
    wam.blender.scaling_w = 0.016

    # Iterate over all video files in the directory
    files = [f for f in os.listdir(assets_dir) if f.endswith(".png") or f.endswith(".jpg")]
    files = [os.path.join(assets_dir, f) for f in files]
    files = files[:num_imgs]

    for file in tqdm(files, desc=f"Processing Images"):
        # load image
        imgs = Image.open(file, "r").convert("RGB")  # keep only rgb channels
        imgs = to_tensor(imgs).unsqueeze(0).float()

        # Watermark embedding
        timer.start()
        outputs = wam.embed(imgs, is_video=False)
        torch.cuda.synchronize()
        # print(f"embedding watermark  - took {timer.stop():.2f}s")

        # compute diff
        imgs_w = outputs["imgs_w"]  # b c h w
        msgs = outputs["msgs"]  # b k
        diff = imgs_w - imgs

        # save
        timer.start()
        base_save_name = os.path.join(output_dir, os.path.basename(file).replace(".png", ""))
        # print(f"saving videos to {base_save_name}")
        save_img(imgs[0], f"{base_save_name}_ori.png")
        save_img(imgs_w[0], f"{base_save_name}_wm.png")
        save_img(20*diff[0].abs(), f"{base_save_name}_diff.png")

        # Compute min and max values, reshape, and normalize
        min_vals = diff.view(imgs.shape[0], imgs.shape[1], -1).min(dim=2, keepdim=True)[0].view(imgs.shape[0], imgs.shape[1], 1, 1)
        max_vals = diff.view(imgs.shape[0], imgs.shape[1], -1).max(dim=2, keepdim=True)[0].view(imgs.shape[0], imgs.shape[1], 1, 1)
        normalized_images = (diff - min_vals) / (max_vals - min_vals)

        # Save the normalized video
        save_img(normalized_images[0], f"{base_save_name}_diff_norm.png")
        # print(f"saving videos - took {timer.stop():.2f}s")

        # Metrics
        imgs_aug = imgs_w
        outputs = wam.detect(imgs_aug, is_video=False)
        metrics = {
            "bit_accuracy": bit_accuracy(
                outputs["preds"][:, 1:],
                msgs
            ).nanmean().item(),
            "psnr": psnr(imgs_w, imgs).item(),
            "ssim": ssim(imgs_w, imgs).item()
        }

        # Augment video
        # print(f"compressing and detecting watermarks")
        for qf in [80, 40]:
            imgs_aug, _ = JPEG()(imgs_w, None,qf)

            # detect
            timer.start()
            outputs = wam.detect(imgs_aug, is_video=True)
            preds = outputs["preds"]
            # print(preds)
            bit_preds = preds[:, 1:]  # b k ...
            bit_accuracy_ = bit_accuracy(
                bit_preds,
                msgs
            ).nanmean().item()
            
            metrics[f"bit_accuracy_qf{qf}"] = bit_accuracy_

        print(metrics)

        del outputs, imgs, imgs_w, diff, min_vals, max_vals, normalized_images

    # Free model from GPU
    del wam
    torch.cuda.empty_cache()

Model loaded successfully from /checkpoint/pfz/2025_logs/0219_vseal_convnextextractor/_nbits=96_lambda_i=0.1_embedder_model=1/checkpoint600.pth with message: <All keys matched successfully>


Processing Images:  10%|█         | 1/10 [00:03<00:33,  3.71s/it]

{'bit_accuracy': 1.0, 'psnr': 39.8629264831543, 'ssim': 0.9945476651191711, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  20%|██        | 2/10 [00:07<00:30,  3.84s/it]

{'bit_accuracy': 1.0, 'psnr': 40.429771423339844, 'ssim': 0.9964190125465393, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  30%|███       | 3/10 [00:12<00:29,  4.26s/it]

{'bit_accuracy': 1.0, 'psnr': 39.848846435546875, 'ssim': 0.9975709915161133, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  40%|████      | 4/10 [00:16<00:25,  4.25s/it]

{'bit_accuracy': 1.0, 'psnr': 39.64069747924805, 'ssim': 0.996229887008667, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  50%|█████     | 5/10 [00:20<00:21,  4.20s/it]

{'bit_accuracy': 1.0, 'psnr': 40.11200714111328, 'ssim': 0.9972754120826721, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  60%|██████    | 6/10 [00:25<00:17,  4.26s/it]

{'bit_accuracy': 1.0, 'psnr': 39.58835983276367, 'ssim': 0.9928290843963623, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  70%|███████   | 7/10 [00:29<00:12,  4.22s/it]

{'bit_accuracy': 1.0, 'psnr': 42.56237030029297, 'ssim': 0.9980538487434387, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  80%|████████  | 8/10 [00:32<00:07,  3.95s/it]

{'bit_accuracy': 1.0, 'psnr': 40.15971374511719, 'ssim': 0.9958235621452332, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images:  90%|█████████ | 9/10 [00:37<00:04,  4.11s/it]

{'bit_accuracy': 1.0, 'psnr': 40.461456298828125, 'ssim': 0.9953904151916504, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}


Processing Images: 100%|██████████| 10/10 [00:40<00:00,  4.08s/it]

{'bit_accuracy': 1.0, 'psnr': 39.8139762878418, 'ssim': 0.9935389161109924, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 1.0}





Model loaded successfully from /checkpoint/pfz/2025_logs/0219_vseal_convnextextractor/_nbits=96_lambda_i=0.1_embedder_model=1/checkpoint400.pth with message: <All keys matched successfully>


Processing Images:  10%|█         | 1/10 [00:03<00:30,  3.37s/it]

{'bit_accuracy': 0.5208333134651184, 'psnr': 46.75999450683594, 'ssim': 0.9991652369499207, 'bit_accuracy_qf80': 0.5104166865348816, 'bit_accuracy_qf40': 0.53125}


Processing Images:  20%|██        | 2/10 [00:06<00:28,  3.52s/it]

{'bit_accuracy': 0.6875, 'psnr': 47.262420654296875, 'ssim': 0.9994601607322693, 'bit_accuracy_qf80': 0.7708333134651184, 'bit_accuracy_qf40': 0.6979166865348816}


Processing Images:  30%|███       | 3/10 [00:11<00:27,  4.00s/it]

{'bit_accuracy': 0.4479166567325592, 'psnr': 45.9100227355957, 'ssim': 0.9995020031929016, 'bit_accuracy_qf80': 0.4375, 'bit_accuracy_qf40': 0.4270833432674408}


Processing Images:  40%|████      | 4/10 [00:15<00:23,  3.97s/it]

{'bit_accuracy': 0.5208333134651184, 'psnr': 46.37257766723633, 'ssim': 0.9990463852882385, 'bit_accuracy_qf80': 0.53125, 'bit_accuracy_qf40': 0.5104166865348816}


Processing Images:  50%|█████     | 5/10 [00:19<00:19,  3.91s/it]

{'bit_accuracy': 0.53125, 'psnr': 45.80270004272461, 'ssim': 0.999504566192627, 'bit_accuracy_qf80': 0.4479166567325592, 'bit_accuracy_qf40': 0.5208333134651184}


Processing Images:  60%|██████    | 6/10 [00:23<00:15,  3.98s/it]

{'bit_accuracy': 0.5, 'psnr': 45.30924606323242, 'ssim': 0.9989807605743408, 'bit_accuracy_qf80': 0.5104166865348816, 'bit_accuracy_qf40': 0.5}


Processing Images:  70%|███████   | 7/10 [00:27<00:11,  3.97s/it]

{'bit_accuracy': 0.7916666865348816, 'psnr': 46.6107063293457, 'ssim': 0.9994402527809143, 'bit_accuracy_qf80': 0.8125, 'bit_accuracy_qf40': 0.6666666865348816}


Processing Images:  80%|████████  | 8/10 [00:30<00:07,  3.75s/it]

{'bit_accuracy': 0.5625, 'psnr': 45.93770980834961, 'ssim': 0.9991415143013, 'bit_accuracy_qf80': 0.5833333134651184, 'bit_accuracy_qf40': 0.625}


Processing Images:  90%|█████████ | 9/10 [00:34<00:03,  3.92s/it]

{'bit_accuracy': 0.5416666865348816, 'psnr': 48.0381965637207, 'ssim': 0.9993906617164612, 'bit_accuracy_qf80': 0.5416666865348816, 'bit_accuracy_qf40': 0.5520833134651184}


Processing Images: 100%|██████████| 10/10 [00:38<00:00,  3.85s/it]

{'bit_accuracy': 0.65625, 'psnr': 44.63625717163086, 'ssim': 0.9983169436454773, 'bit_accuracy_qf80': 0.6979166865348816, 'bit_accuracy_qf40': 0.6041666865348816}





## With attenuation

In [None]:
# Directory containing videos
assets_dir = "/checkpoint/pfz/projects/videoseal/assets/imgs"
output_dir = "outputs"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Checkpoint
ckpt_path = "videoseal"
# ckpt_path = "/checkpoint/pfz/2025_logs/0115_vseal_rgb_96bits_nopercep_yuv/_scaling_w=0.05_lambda_d=0.5_extractor_model=sam_small/checkpoint.pth"

# a timer to measure the time
timer = Timer()

# Iterate over all checkpoints
wam = setup_model_from_checkpoint(ckpt_path)
wam.eval()
wam.to(device)

# create attenuation
attenuation = VarianceBasedJND(
    mode="variance",
    max_variance_value_for_clipping=300,
    min_heatmap_value=0.1,
    avg_pool_kernel_size=3
)
wam.attenuation = attenuation
wam.blender.scaling_w = 20.0

# Iterate over all video files in the directory
files = [f for f in os.listdir(assets_dir) if f.endswith(".png")]
files = [os.path.join(assets_dir, f) for f in files]

for file in tqdm(files, desc=f"Processing Images"):
    # load image
    imgs = Image.open(file, "r").convert("RGB")  # keep only rgb channels
    imgs = to_tensor(imgs).unsqueeze(0).float()

    # Watermark embedding
    timer.start()
    outputs = wam.embed(imgs, is_video=False)
    # torch.cuda.synchronize()
    # print(f"embedding watermark  - took {timer.stop():.2f}s")

    # compute diff
    imgs_w = outputs["imgs_w"]  # b c h w
    msgs = outputs["msgs"]  # b k
    diff = imgs_w - imgs

    # save
    timer.start()
    base_save_name = os.path.join(output_dir, os.path.basename(file).replace(".png", ""))
    save_img(imgs[0], f"{base_save_name}_ori.png")
    save_img(imgs_w[0], f"{base_save_name}_wm.png")
    save_img(diff[0], f"{base_save_name}_diff.png")

    # Compute min and max values, reshape, and normalize
    min_vals = diff.view(imgs.shape[0], imgs.shape[1], -1).min(dim=2, keepdim=True)[0].view(imgs.shape[0], imgs.shape[1], 1, 1)
    max_vals = diff.view(imgs.shape[0], imgs.shape[1], -1).max(dim=2, keepdim=True)[0].view(imgs.shape[0], imgs.shape[1], 1, 1)
    normalized_images = (diff - min_vals) / (max_vals - min_vals)

    # Save the normalized video
    save_img(normalized_images[0], f"{base_save_name}_diff_norm.png")

    # Metrics
    imgs_aug = imgs_w
    outputs = wam.detect(imgs_aug, is_video=False)
    metrics = {
        "bit_accuracy": bit_accuracy(
            outputs["preds"][:, 1:],
            msgs
        ).nanmean().item(),
        "psnr": psnr(imgs_w, imgs).item(),
        "ssim": ssim(imgs_w, imgs).item()
    }

    # Augment video
    for qf in [80, 40]:
        imgs_aug, _ = JPEG()(imgs_w, None,qf)

        # detect
        timer.start()
        outputs = wam.detect(imgs_aug, is_video=False)
        preds = outputs["preds"]
        bit_preds = preds[:, 1:]  # b k ...
        bit_accuracy_ = bit_accuracy(
            bit_preds,
            msgs
        ).nanmean().item()
        metrics[f"bit_accuracy_qf_{qf}"] = bit_accuracy_
    
    print(metrics)
    del outputs, imgs, imgs_w, diff, min_vals, max_vals, normalized_images

# Free model from GPU
del wam
torch.cuda.empty_cache()