Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

# Video Seal - Image inference

In [1]:
# Reload edited project modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Move to the parent directory so the rest of the notebook runs from the root of
# the repository (relative paths such as "assets/imgs" are resolved from there).
# NOTE(review): this is not idempotent - re-running the cell moves up one more
# directory each time; restart the kernel before re-running it.
%cd ..

/private/home/pfz/09-videoseal/fbresearch-new


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [None]:
from videoseal.utils.display import save_img
from videoseal.utils import Timer
from videoseal.evals.full import setup_model_from_checkpoint
from videoseal.evals.metrics import bit_accuracy, psnr, ssim
from videoseal.augmentation import Identity, JPEG
from videoseal.modules.jnd import JND

import os
import omegaconf
from tqdm import tqdm
import gc
from PIL import Image

import torch
import torchvision

# Converters between PIL images and tensors (torchvision transforms).
to_tensor = torchvision.transforms.ToTensor()
to_pil = torchvision.transforms.ToPILImage()

# Device the models are moved to. It is pinned to the CPU here; swap in the
# commented line below to select CUDA automatically when it is available.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu" 

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Run configuration: maximum number of images to process, directory containing
# the input images, and root directory for everything written by this notebook.
num_imgs = 10
assets_dir = "assets/imgs"
base_output_dir = "outputs"
os.makedirs(base_output_dir, exist_ok=True)

# Checkpoints to evaluate. Each key is a run name (used as the output
# sub-directory) and each value is the checkpoint identifier handed to
# setup_model_from_checkpoint. Uncomment entries to evaluate more models.
ckpts = {
    # "trustmark": "baseline/trustmark",
    # "model": "baseline/model",
    # "cin": "baseline/cin",
    # "mbrs": "baseline/mbrs",
    # "videoseal_0.0": 'videoseal_0.0',
    "videoseal": 'videoseal',
}

# Evaluate every checkpoint in `ckpts`: watermark up to `num_imgs` images from
# `assets_dir`, save the original / watermarked / difference images under
# `<base_output_dir>/<ckpt_name>/`, and print per-image quality (PSNR, SSIM) and
# bit-accuracy metrics, both on the clean watermarked image and after JPEG
# compression.

# Only synchronize with the GPU when the model actually runs on it; an
# unconditional torch.cuda.synchronize() fails on machines without CUDA.
on_cuda = torch.device(device).type == "cuda"

for ckpt_name, ckpt_path in ckpts.items():

    output_dir = os.path.join(base_output_dir, ckpt_name)
    os.makedirs(output_dir, exist_ok=True)

    # a timer to measure the time
    timer = Timer()

    # Load the model for this checkpoint
    model = setup_model_from_checkpoint(ckpt_path)
    # model.blender.scaling_w = 0.2
    model.eval()
    model.compile()
    model.to(device)

    # Collect the image files of the directory. os.listdir returns entries in
    # arbitrary order, so sort them to make the `num_imgs` subset reproducible.
    files = [f for f in sorted(os.listdir(assets_dir)) if f.endswith((".png", ".jpg"))]
    files = [os.path.join(assets_dir, f) for f in files]
    files = files[:num_imgs]

    # The JPEG augmentation does not depend on the image: build it once.
    jpeg = JPEG()

    for file in tqdm(files, desc="Processing Images"):
        # load image; the context manager closes the underlying file handle
        with Image.open(file, "r") as pil_img:
            rgb_img = pil_img.convert("RGB")  # keep only rgb channels
        # NOTE(review): the batch stays on the CPU, as before; if `device` is
        # switched to CUDA, confirm that model.embed moves its inputs itself.
        imgs = to_tensor(rgb_img).unsqueeze(0).float()

        # Watermark embedding
        timer.start()
        outputs = model.embed(imgs, is_video=False, lowres_attenuation=True)
        if on_cuda:
            torch.cuda.synchronize()
        # print(f"embedding watermark  - took {timer.stop():.2f}s")

        # compute diff
        imgs_w = outputs["imgs_w"]  # b c h w
        msgs = outputs["msgs"]  # b k
        diff = imgs_w - imgs

        # save; strip the real extension (".png" or ".jpg") so that "1.jpg"
        # yields "1_ori.png" rather than "1.jpg_ori.png"
        timer.start()
        stem = os.path.splitext(os.path.basename(file))[0]
        base_save_name = os.path.join(output_dir, stem)
        # print(f"saving images to {base_save_name}")
        save_img(imgs[0], f"{base_save_name}_ori.png")
        save_img(imgs_w[0], f"{base_save_name}_wm.png")
        save_img(20*diff[0].abs(), f"{base_save_name}_diff.png")

        # Per-image, per-channel min-max normalization of the difference.
        # The denominator is clamped so a constant channel gives 0 instead of NaN.
        min_vals = diff.amin(dim=(2, 3), keepdim=True)  # b c 1 1
        max_vals = diff.amax(dim=(2, 3), keepdim=True)  # b c 1 1
        normalized_images = (diff - min_vals) / (max_vals - min_vals).clamp_min(1e-8)

        # Save the normalized difference image
        save_img(normalized_images[0], f"{base_save_name}_diff_norm.png")
        # print(f"saving images - took {timer.stop():.2f}s")

        # Metrics on the un-augmented watermarked image
        imgs_aug = imgs_w
        outputs = model.detect(imgs_aug, is_video=False)
        metrics = {
            "file": file,
            "bit_accuracy": bit_accuracy(
                outputs["preds"][:, 1:],
                msgs
            ).nanmean().item(),
            "psnr": psnr(imgs_w, imgs).item(),
            "ssim": ssim(imgs_w, imgs).item()
        }

        # Robustness: JPEG-compress the watermarked image and detect again
        # print(f"compressing and detecting watermarks")
        for qf in [80, 40]:
            imgs_aug, _ = jpeg(imgs_w, None, qf)

            # detect; the input is a batch of images, so use the same
            # is_video=False setting as the clean detection above
            timer.start()
            outputs = model.detect(imgs_aug, is_video=False)
            preds = outputs["preds"]
            # print(preds)
            bit_preds = preds[:, 1:]  # b k ...
            bit_accuracy_ = bit_accuracy(
                bit_preds,
                msgs
            ).nanmean().item()

            metrics[f"bit_accuracy_qf{qf}"] = bit_accuracy_

        print(metrics)

        # Drop the per-image tensors before processing the next file
        del outputs, imgs, imgs_w, diff, min_vals, max_vals, normalized_images

    # Free the model (and any cached GPU memory) before loading the next checkpoint
    del model
    gc.collect()
    torch.cuda.empty_cache()

File /private/home/pfz/09-videoseal/fbresearch-new/ckpts/y_256b_img.pth exists, skipping download


Model loaded successfully from /private/home/pfz/09-videoseal/fbresearch-new/ckpts/y_256b_img.pth with message: <All keys matched successfully>


Processing Images: 100%|██████████| 1/1 [00:02<00:00,  2.46s/it]

{'file': 'assets/imgs/1.jpg', 'bit_accuracy': 1.0, 'psnr': 48.22146987915039, 'ssim': 0.9974438548088074, 'bit_accuracy_qf80': 1.0, 'bit_accuracy_qf40': 0.98828125}



