# High-resolution inference

In [1]:
# run in the root of the repository
%load_ext autoreload
%autoreload 2
 
%cd ../..

/private/home/pfz/09-videoseal/baselines


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
from videoseal.utils.display import save_img
from videoseal.utils import Timer
from videoseal.evals.full import setup_model_from_checkpoint
from videoseal.evals.metrics import bit_accuracy
from videoseal.augmentation import Identity, JPEG
from videoseal.modules.jnd import JND

import os
import omegaconf
from tqdm import tqdm
import gc
from PIL import Image

import torch
import torchvision

# Converters between PIL images and tensors, shared by the cells below.
to_tensor, to_pil = torchvision.transforms.ToTensor(), torchvision.transforms.ToPILImage()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# device = "cpu"  # uncomment to force CPU

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Directory containing the input images (only *.png files are processed)
# assets_dir = "/checkpoint/pfz/projects/videoseal/assets/imgs"
assets_dir = "/private/home/pfz/_images"
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)

# Checkpoints to evaluate: display name -> checkpoint path.
# Uncomment an entry to include it in the run.
ckpts = {
    # "hidden": '/private/home/hadyelsahar/work/code/videoseal/2024_logs/videoseal0.1/_lambda_d=0.5_lambda_i=0.0_optimizer=AdamW,lr=1e-4_videowam_step_size=4_video_start=500_embedder_model=hidden/checkpoint.pth',
    # "scaling_laws_smalldetector_tinyembedder":"/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1105-videoseal0.2-scalinglaws/_lambda_d=0.0_extractor_model=sam_small_embedder_model=0/checkpoint.pth",
    # "scaling_laws_tinydetector_tinyembedder":"/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1105-videoseal0.2-scalinglaws/_lambda_d=0.0_extractor_model=sam_tiny_embedder_model=0/checkpoint.pth",
    # "JND_fix_discloss":"/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1108-videoseal0.2-discloss-fix-removeunused-params/_attenuation=jnd_3_3_nbits=64_lambda_d=0.5_video_start=100/checkpoint.pth",
    # "1111_discloss_sleepwake_highssim":"/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1109-videoseal0.2-discloss-fix-hing-sleepwake-4nodes/_scaling_w=0.1_lambda_i=0.25_disc_hinge_on_logits_fake=False_sleepwake=True_video_start=500/checkpoint.pth"
    # "1111-finetuned":"/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1111-videoseal0.2-sleepwake-resume/_attenuation=jnd_1_1/checkpoint.pth"
    # "1112-videoseal0.2":"/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1111-videoseal0.2-archsearch-4nodes/_attenuation=None_nbits=64_finetune_detector_start=1000_embedder_model=unet_small2_quant/checkpoint.pth",
    # "1118-yuv-400":"/checkpoint/pfz/2024_logs/1118_vseal_long_sab/_scheduler=0_optimizer=adopt,lr=1e-5/checkpoint400.pth",
    # "1118-yuv-800":"/checkpoint/pfz/2024_logs/1118_vseal_long_sab/_scheduler=1_optimizer=AdamW,lr=1e-5/checkpoint.pth",
    # "trustmark": "baseline/trustmark",
    "videoseal0.1": '/private/home/hadyelsahar/work/code/videoseal/2024_logs/videoseal0.1/_lambda_d=0.5_lambda_i=0.5_optimizer=AdamW,lr=1e-4_videowam_step_size=4_video_start=500_embedder_model=unet_small2/checkpoint.pth',
    # "videoseal0.2b": "/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1111-videoseal0.2-archsearch-4nodes/_attenuation=None_nbits=64_finetune_detector_start=800_embedder_model=unet_small2_quant/checkpoint.pth",
    # "videoseal0.2a": "/private/home/hadyelsahar/work/code/videoseal/2024_logs_large-exp/1109-videoseal0.2-discloss-fix-hing-sleepwake-4nodes/_scaling_w=0.5_lambda_i=0.5_disc_hinge_on_logits_fake=True_sleepwake=False_video_start=500/checkpoint.pth",
    # "videoseal0.4": "/large_experiments/meres/hadyelsahar/2024_logs/1120-videoseal0.4/_scaling_w=0.5_sleepwake=False_videowam_step_size=4_extractor_model=sam_tiny/checkpoint.pth",
    # "wam": "baseline/wam",
    # "cin": "baseline/cin",
    # "mbrs": "baseline/mbrs",
}

# Collect the input images once (the list does not depend on the checkpoint);
# sorted so that the processing order is reproducible across runs.
files = sorted(
    os.path.join(assets_dir, f) for f in os.listdir(assets_dir) if f.endswith(".png")
)

# Iterate over all checkpoints
for ckpt_name, ckpt_path in ckpts.items():
    # ckpt_path = "/checkpoint/pfz/2024_logs/1028_vseal_long/_seed=3_optimizer=AdamW,lr=1e-4_embedder_model=unet_small2_yuv/checkpoint600.pth"

    # a timer to measure the duration of each stage
    timer = Timer()

    # load the model and move it to the target device
    wam = setup_model_from_checkpoint(ckpt_path)
    wam.eval()
    wam.to(device)

    # Keep the results of different checkpoints apart so they do not overwrite each
    # other; with a single checkpoint the files go directly into output_dir.
    ckpt_output_dir = os.path.join(output_dir, ckpt_name) if len(ckpts) > 1 else output_dir
    os.makedirs(ckpt_output_dir, exist_ok=True)

    # Iterate over all image files in the directory
    for file in tqdm(files, desc="Processing Images"):
        # load image (keep only rgb channels) and add a leading batch dimension
        with Image.open(file) as img:
            imgs = to_tensor(img.convert("RGB")).unsqueeze(0).float()

        # Watermark embedding
        timer.start()
        outputs = wam.embed(imgs, is_video=False)
        if torch.cuda.is_available():
            # wait for pending GPU work so the measured time is meaningful; guarded
            # because `device` may be the CPU, where synchronize() would raise
            torch.cuda.synchronize()
        print(f"embedding watermark  - took {timer.stop():.2f}s")

        # compute diff
        imgs_w = outputs["imgs_w"]  # b c h w
        msgs = outputs["msgs"]  # b k
        diff = imgs_w - imgs

        # save original, watermarked image and the amplified absolute difference
        timer.start()
        # splitext strips only the extension (replace would remove every ".png" in the name)
        base_save_name = os.path.join(ckpt_output_dir, os.path.splitext(os.path.basename(file))[0])
        print(f"saving videos to {base_save_name}")
        save_img(imgs[0], f"{base_save_name}_ori.png")
        save_img(imgs_w[0], f"{base_save_name}_wm.png")
        save_img(20*diff[0].abs(), f"{base_save_name}_diff.png")

        # Per-image, per-channel min-max normalization of the difference
        min_vals = diff.amin(dim=(2, 3), keepdim=True)  # b c 1 1
        max_vals = diff.amax(dim=(2, 3), keepdim=True)  # b c 1 1
        # clamp the range so a constant channel yields zeros instead of 0/0 = NaN
        normalized_images = (diff - min_vals) / (max_vals - min_vals).clamp(min=1e-8)

        # Save the normalized difference
        save_img(normalized_images[0], f"{base_save_name}_diff_norm.png")
        print(f"saving videos - took {timer.stop():.2f}s")

        # JPEG-compress the watermarked image and measure how well the message survives
        print("compressing and detecting watermarks")
        for qf in [80, 40]:
            imgs_aug, _ = JPEG()(imgs_w, None, qf)

            # detect
            timer.start()
            # NOTE(review): embedding uses is_video=False while detection uses
            # is_video=True on the same single image — confirm this is intended.
            outputs = wam.detect(imgs_aug, is_video=True)
            preds = outputs["preds"]
            # drop the first prediction channel (presumably a detection score — TODO confirm)
            bit_preds = preds[:, 1:]  # b k ...
            bit_accuracy_ = bit_accuracy(
                bit_preds,
                msgs
            ).nanmean().item()
            print(f"bit accuracy at JPEG {qf} is {bit_accuracy_:.2f} - took {timer.stop():.2f}s")

        # release per-image tensors before loading the next image
        del outputs, imgs, imgs_w, diff, min_vals, max_vals, normalized_images

    # Free model from GPU
    del wam
    torch.cuda.empty_cache()

Model loaded successfully from /private/home/hadyelsahar/work/code/videoseal/2024_logs/videoseal0.1/_lambda_d=0.5_lambda_i=0.5_optimizer=AdamW,lr=1e-4_videowam_step_size=4_video_start=500_embedder_model=unet_small2/checkpoint.pth with message: _IncompatibleKeys(missing_keys=['rgb2yuv.M'], unexpected_keys=[])


Processing Images:   0%|          | 0/20 [00:00<?, ?it/s]

embedding watermark  - took 0.02s
saving videos to outputs/chao
saving videos - took 1.95s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s


Processing Images:   5%|▌         | 1/20 [00:02<00:42,  2.26s/it]

bit accuracy at JPEG 40 is 0.97 - took 0.02s
embedding watermark  - took 0.01s
saving videos to outputs/corgi_avocado


Processing Images:  10%|█         | 2/20 [00:02<00:20,  1.16s/it]

saving videos - took 0.31s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.01s
bit accuracy at JPEG 40 is 0.91 - took 0.01s
embedding watermark  - took 0.01s
saving videos to outputs/trex_bike


Processing Images:  15%|█▌        | 3/20 [00:03<00:17,  1.06s/it]

saving videos - took 0.80s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.01s
bit accuracy at JPEG 40 is 0.97 - took 0.01s
embedding watermark  - took 0.01s
saving videos to outputs/tahiti


Processing Images:  20%|██        | 4/20 [00:04<00:18,  1.16s/it]

saving videos - took 1.12s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 0.98 - took 0.01s
bit accuracy at JPEG 40 is 0.97 - took 0.02s
embedding watermark  - took 0.01s
saving videos to outputs/tahiti_512


Processing Images:  25%|██▌       | 5/20 [00:05<00:13,  1.11it/s]

saving videos - took 0.36s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.01s
bit accuracy at JPEG 40 is 0.88 - took 0.01s
embedding watermark  - took 0.01s
saving videos to outputs/tahiti_256
saving videos - took 0.11s
compressing and detecting watermarks


Processing Images:  30%|███       | 6/20 [00:05<00:09,  1.54it/s]

bit accuracy at JPEG 80 is 0.98 - took 0.01s
bit accuracy at JPEG 40 is 0.77 - took 0.01s
embedding watermark  - took 0.01s
saving videos to outputs/gauguin


Processing Images:  35%|███▌      | 7/20 [00:07<00:12,  1.00it/s]

saving videos - took 1.51s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 0.98 - took 0.02s
bit accuracy at JPEG 40 is 0.98 - took 0.02s
embedding watermark  - took 0.01s
saving videos to outputs/gauguin_512


Processing Images:  40%|████      | 8/20 [00:07<00:09,  1.21it/s]

saving videos - took 0.37s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 0.97 - took 0.02s
bit accuracy at JPEG 40 is 0.83 - took 0.02s
embedding watermark  - took 0.01s
saving videos to outputs/gauguin_256
saving videos - took 0.10s
compressing and detecting watermarks


Processing Images:  45%|████▌     | 9/20 [00:07<00:06,  1.61it/s]

bit accuracy at JPEG 80 is 0.86 - took 0.02s
bit accuracy at JPEG 40 is 0.73 - took 0.02s
embedding watermark  - took 0.01s
saving videos to outputs/hific


Processing Images:  50%|█████     | 10/20 [00:08<00:06,  1.59it/s]

saving videos - took 0.54s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s
bit accuracy at JPEG 40 is 0.81 - took 0.02s
embedding watermark  - took 0.02s
saving videos to outputs/gfpgan


Processing Images:  55%|█████▌    | 11/20 [00:10<00:08,  1.08it/s]

saving videos - took 1.41s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 0.98 - took 0.02s
bit accuracy at JPEG 40 is 0.98 - took 0.02s
embedding watermark  - took 0.02s
saving videos to outputs/woman


Processing Images:  60%|██████    | 12/20 [00:11<00:08,  1.09s/it]

saving videos - took 1.25s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s
bit accuracy at JPEG 40 is 1.00 - took 0.02s
embedding watermark  - took 0.02s
saving videos to outputs/gfpgan_hf


Processing Images:  65%|██████▌   | 13/20 [00:13<00:08,  1.24s/it]

saving videos - took 1.38s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s
bit accuracy at JPEG 40 is 0.98 - took 0.02s
embedding watermark  - took 0.04s
saving videos to outputs/tahiti_photo
saving videos - took 3.66s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s


Processing Images:  70%|███████   | 14/20 [00:17<00:12,  2.16s/it]

bit accuracy at JPEG 40 is 1.00 - took 0.02s
embedding watermark  - took 0.02s
saving videos to outputs/pope


Processing Images:  75%|███████▌  | 15/20 [00:19<00:10,  2.06s/it]

saving videos - took 1.61s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s
bit accuracy at JPEG 40 is 0.98 - took 0.02s
embedding watermark  - took 0.01s
saving videos to outputs/tree-ring-bear


Processing Images:  80%|████████  | 16/20 [00:19<00:06,  1.57s/it]

saving videos - took 0.35s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 0.98 - took 0.02s
bit accuracy at JPEG 40 is 0.80 - took 0.02s
embedding watermark  - took 0.05s
saving videos to outputs/duck2
saving videos - took 1.33s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.03s


Processing Images:  85%|████████▌ | 17/20 [00:21<00:05,  1.67s/it]

bit accuracy at JPEG 40 is 0.97 - took 0.03s
embedding watermark  - took 0.07s
saving videos to outputs/duck1
saving videos - took 1.27s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 0.98 - took 0.02s


Processing Images:  90%|█████████ | 18/20 [00:23<00:03,  1.72s/it]

bit accuracy at JPEG 40 is 0.83 - took 0.02s
embedding watermark  - took 0.07s
saving videos to outputs/ducks
saving videos - took 4.04s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s


Processing Images:  95%|█████████▌| 19/20 [00:28<00:02,  2.64s/it]

bit accuracy at JPEG 40 is 1.00 - took 0.02s
embedding watermark  - took 0.03s
saving videos to outputs/videoseal
saving videos - took 2.43s
compressing and detecting watermarks
bit accuracy at JPEG 80 is 1.00 - took 0.02s


Processing Images: 100%|██████████| 20/20 [00:30<00:00,  1.55s/it]

bit accuracy at JPEG 40 is 1.00 - took 0.02s





## With attenuation

In [None]:
# Directory containing the input images (only *.png files are processed)
assets_dir = "assets/imgs"
output_dir = "output"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(output_dir, exist_ok=True)

# Checkpoint
ckpt_path = "/checkpoint/pfz/2024_logs/1028_vseal_long/_seed=3_optimizer=AdamW,lr=1e-4_embedder_model=unet_small2_yuv/checkpoint600.pth"

# a timer to measure the detection time
timer = Timer()

# load the model and move it to the target device
wam = setup_model_from_checkpoint(ckpt_path)
wam.eval()
wam.to(device)

# create attenuation: load the named entry of the attenuation config and build the
# JND module from it (path and entry name get their own variables instead of being
# overwritten by the loaded config / the module)
attenuation_cfg_path = "configs/attenuation.yaml"
attenuation_name = "jnd_1_1"
attenuation_cfg = omegaconf.OmegaConf.load(attenuation_cfg_path)[attenuation_name]
attenuation = JND(**attenuation_cfg)
wam.attenuation = attenuation
wam.scaling_w = 0.2

# Collect all image files in the directory; sorted so the order is reproducible
files = sorted(
    os.path.join(assets_dir, f) for f in os.listdir(assets_dir) if f.endswith(".png")
)

for file in tqdm(files, desc="Processing Videos"):
    # load image (keep only rgb channels) and add a leading batch dimension
    with Image.open(file) as img:
        imgs = to_tensor(img.convert("RGB")).unsqueeze(0).float()

    # Watermark embedding (not timed in this cell; only detection is timed below)
    outputs = wam.embed(imgs, is_video=False)

    # compute diff
    imgs_w = outputs["imgs_w"]  # b c h w
    msgs = outputs["msgs"]  # b k
    diff = imgs_w - imgs

    # save original, watermarked image and the difference
    # splitext strips only the extension (replace would remove every ".png" in the name)
    base_save_name = os.path.join(output_dir, os.path.splitext(os.path.basename(file))[0])
    save_img(imgs[0], f"{base_save_name}_ori.png")
    save_img(imgs_w[0], f"{base_save_name}_wm.png")
    # the diff is signed; save its amplified magnitude, as the cell without
    # attenuation does, rather than the raw signed values
    save_img(20*diff[0].abs(), f"{base_save_name}_diff.png")

    # Per-image, per-channel min-max normalization of the difference
    min_vals = diff.amin(dim=(2, 3), keepdim=True)  # b c 1 1
    max_vals = diff.amax(dim=(2, 3), keepdim=True)  # b c 1 1
    # clamp the range so a constant channel yields zeros instead of 0/0 = NaN
    normalized_images = (diff - min_vals) / (max_vals - min_vals).clamp(min=1e-8)

    # Save the normalized difference
    save_img(normalized_images[0], f"{base_save_name}_diff_norm.png")

    # JPEG-compress the watermarked image and measure how well the message survives
    for qf in [80, 40]:
        imgs_aug, _ = JPEG()(imgs_w, None, qf)

        # detect
        timer.start()
        # NOTE(review): embedding uses is_video=False while detection uses
        # is_video=True on the same single image — confirm this is intended.
        outputs = wam.detect(imgs_aug, is_video=True)
        preds = outputs["preds"]
        # drop the first prediction channel (presumably a detection score — TODO confirm)
        bit_preds = preds[:, 1:]  # b k ...
        bit_accuracy_ = bit_accuracy(
            bit_preds,
            msgs
        ).nanmean().item()
        print(f"bit accuracy at JPEG {qf} is {bit_accuracy_:.2f} - took {timer.stop():.2f}s")

    # release per-image tensors before loading the next image
    del outputs, imgs, imgs_w, diff, min_vals, max_vals, normalized_images

# Free model from GPU
del wam
torch.cuda.empty_cache()