In [1]:
# run in the root of the repository
%cd ../..

/private/home/pfz/09-videoseal/videoseal-dev


In [2]:
import json
import argparse
import os
import omegaconf
import numpy as np
import imageio
import cv2

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

from PIL import Image
from skimage.metrics import peak_signal_noise_ratio

from videoseal.models import Wam, build_embedder, build_extractor
from videoseal.augmentation.augmenter import Augmenter
from videoseal.data.transforms import default_transform, normalize_img, unnormalize_img
from videoseal.data.datasets import VideoDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load and build models 

In [9]:
# load video
video_dir = "assets/videos"
video_path = "assets/videos/sav_013754.mp4"
# !ffprobe -v error -show_entries stream=r_frame_rate -of default=noprint_wrappers=1:nokey=1 assets/videos/sav_013754.mp4

fps = 24 // 1
frames_per_clip = fps * 3 # 3s
frame_step = 1

vid_dataset = VideoDataset(
    folder_paths = [video_dir], 
    frames_per_clip = frames_per_clip,
    frame_step = frame_step
)
vid = vid_dataset.__getitem__(0)
video_tensor = vid[0][0]
video_tensor = np.transpose(video_tensor, (0, 3, 1, 2))
video_tensor = torch.tensor(video_tensor, dtype=torch.float32)
print(f"Video tensor shape: {video_tensor.shape}")

INFO:videoseal.data.datasets:Loading videos from assets/videos
INFO:videoseal.data.datasets:Found 2 videos in assets/videos
Processing videos in assets/videos: 100%|██████████| 2/2 [00:00<00:00, 33026.02it/s]
INFO:videoseal.data.datasets:Total videos loaded from assets/videos: 2


Video tensor shape: torch.Size([72, 3, 1920, 1080])


In [10]:
from videoseal.augmentation.augmenter import Augmenter
from videoseal.modules.jnd import JND
from videoseal.models.embedder import Embedder
from videoseal.models.extractor import Extractor

class VideoWam(nn.Module):
    wm_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        embedder: Embedder,
        detector: Extractor,
        augmenter: Augmenter,
        attenuation: JND = None,
        scaling_w: float = 1.0,
        scaling_i: float = 1.0,
        img_size: int = 256,
        chunk_size: int = 8,
        step_size: int = 4,
        device: str = device,
    ) -> None:
        """
        WAM (watermark-anything models) model that combines an embedder, a detector, and an augmenter.
        Embeds a message into an image and detects it as a mask.

        Arguments:
            embedder: The watermark embedder
            detector: The watermark detector
            augmenter: The image augmenter
            attenuation: The JND model to attenuate the watermark distortion
            scaling_w: The scaling factor for the watermark
            scaling_i: The scaling factor for the image
        """
        super().__init__()
        # modules
        self.embedder = embedder
        self.detector = detector
        self.augmenter = augmenter
        self.attenuation = attenuation
        # scalings
        self.scaling_w = scaling_w
        self.scaling_i = scaling_i
        # video settings
        self.chunk_size = chunk_size  # encode 8 imgs at a time
        self.step_size = step_size  # propagate the wm to 4 next imgs
        self.resize_to = transforms.Resize(img_size, antialias=True)
        # device
        self.device = device

    def get_random_msg(self, bsz: int = 1, nb_repetitions=1) -> torch.Tensor:
        return self.embedder.get_random_msg(bsz, nb_repetitions)  # b x k

    @torch.no_grad()
    def embed_inference(
        self,
        imgs: torch.Tensor,
        msg: torch.Tensor = None,
    ):
        """ 
        Does the forward pass of the encoder only.
        Rescale the watermark signal by a JND (just noticeable difference heatmap) that says where pixel can be changed without being noticed.
        The watermark signal is computed on the image downsampled to 256x... pixels, and then upsampled to the original size.
        The watermark signal is computed every step_size imgs and propagated to the next step_size imgs.

        Args:
            imgs: (torch.Tensor) Batched images with shape FxCxHxW
            msg: (torch.Tensor) Batched messages with shape 1xL
        """
        if msg is None:
            msg = self.get_random_msg()

        # encode by chunk of 8 imgs, propagate the wm to 4 next imgs
        chunk_size = self.chunk_size  # n
        step_size = self.step_size
        msg = msg.repeat(chunk_size, 1).to(self.device) # 1 k -> n k

        # initialize watermarked imgs
        imgs_w = torch.zeros_like(imgs) # f 3 h w

        for ii in range(0, len(imgs[::step_size]), chunk_size):
            nimgs_in_ck = min(chunk_size, len(imgs[::step_size]) - ii)
            start = ii*step_size
            end = start + nimgs_in_ck * step_size
            all_imgs_in_ck = imgs[start : end, ...].to(self.device) # f 3 h w

            # choose one frame every step_size
            imgs_in_ck = all_imgs_in_ck[::step_size] # n 3 h w
            # downsampling with fixed short edge
            imgs_in_ck = self.resize_to(imgs_in_ck) # n 3 wm_h wm_w
            # deal with last chunk that may have less than chunk_size frames
            if nimgs_in_ck < chunk_size:  
                msg = msg[:nimgs_in_ck]
            
            # get deltas for the chunk, and repeat them for each frame in the chunk
            deltas_in_ck = self.embedder(imgs_in_ck, msg) # n 3 wm_h wm_w
            deltas_in_ck = torch.repeat_interleave(deltas_in_ck, step_size, dim=0) # f 3 wm_h wm_w
            deltas_in_ck = deltas_in_ck[:len(all_imgs_in_ck)] # at the end of video there might be more deltas than needed
            
            # upsampling
            deltas_in_ck = nn.functional.interpolate(deltas_in_ck, size=imgs.shape[-2:], mode='bilinear', align_corners=True)
            
            # create watermarked imgs
            all_imgs_in_ck_w = self.scaling_i * all_imgs_in_ck + self.scaling_w * deltas_in_ck
            if self.attenuation is not None:
                all_imgs_in_ck_w = self.attenuation(all_imgs_in_ck, all_imgs_in_ck_w)
            imgs_w[start : end, ...] = all_imgs_in_ck_w.cpu() # n 3 h w

        return imgs_w

    @torch.no_grad()
    def detect_inference(
        self,
        imgs: torch.Tensor,
    ):
        """
        ...
        
        Args:
            imgs: (torch.Tensor) Batched images with shape FxCxHxW
        """
        ....


In [11]:
def load_model_from_checkpoint(exp_dir, exp_name):
    logfile_path = os.path.join(exp_dir, 'logs', exp_name + '.stdout')
    ckpt_path = os.path.join(exp_dir, exp_name, 'checkpoint.pth')

    # Load parameters from log file
    with open(logfile_path, 'r') as file:
        for line in file:
            if '__log__:' in line:
                params = json.loads(line.split('__log__:')[1].strip())
                break

    # Create an argparse Namespace object from the parameters
    args = argparse.Namespace(**params)
    print(args)
    
    # Load configurations
    for path in [args.embedder_config, args.extractor_config, args.augmentation_config]:
        path = os.path.join(exp_dir, "code", path)
    # embedder
    embedder_cfg = omegaconf.OmegaConf.load(args.embedder_config)
    args.embedder_model = args.embedder_model or embedder_cfg.model
    embedder_params = embedder_cfg[args.embedder_model]
    # extractor
    extractor_cfg = omegaconf.OmegaConf.load(args.extractor_config)
    args.extractor_model = args.extractor_model or extractor_cfg.model
    extractor_params = extractor_cfg[args.extractor_model]
    # augmenter
    augmenter_cfg = omegaconf.OmegaConf.load(args.augmentation_config)
    
    # Build models
    embedder = build_embedder(args.embedder_model, embedder_params, args.nbits)
    extractor = build_extractor(extractor_cfg.model, extractor_params, args.img_size_extractor, args.nbits)
    augmenter = Augmenter(**augmenter_cfg)
    
    # Build the complete model
    wam = VideoWam(embedder, extractor, augmenter, 
                   scaling_w=args.scaling_w, scaling_i=args.scaling_i)
    
    # Load the model weights
    if os.path.exists(ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        wam.load_state_dict(checkpoint['model'])
        print("Model loaded successfully from", ckpt_path)
        print(line)
    else:
        print("Checkpoint path does not exist:", ckpt_path)
    
    return wam

# Example usage
exp_dir = '/checkpoint/pfz/2024_logs/0911_vseal_pw'
exp_name = '_extractor_model=sam_tiny'

wam = load_model_from_checkpoint(exp_dir, exp_name)
wam.eval()
wam.to(device)

Namespace(train_dir='/datasets01/COCO/060817/train2014/', train_annotation_file='/datasets01/COCO/060817/annotations/instances_train2014.json', val_dir='/datasets01/COCO/060817/val2014/', val_annotation_file='/datasets01/COCO/060817/annotations/instances_val2014.json', output_dir='/checkpoint/pfz/2024_logs/0911_vseal_pw/_extractor_model=sam_tiny', embedder_config='configs/embedder.yaml', augmentation_config='configs/simple_augs.yaml', extractor_config='configs/extractor.yaml', attenuation_config='configs/attenuation.yaml', embedder_model='unet_small2', extractor_model='sam_tiny', nbits=32, img_size=256, img_size_extractor=256, attenuation='None', scaling_w=0.4, scaling_w_schedule=None, scaling_i=1.0, threshold_mask=0.6, optimizer='AdamW,lr=1e-4', optimizer_d=None, scheduler='CosineLRScheduler,lr_min=1e-6,t_initial=100,warmup_lr_init=1e-6,warmup_t=5', epochs=100, batch_size=16, batch_size_eval=32, temperature=1.0, workers=8, resume_from=None, lambda_det=0.0, lambda_dec=1.0, lambda_i=0.0

  checkpoint = torch.load(ckpt_path, map_location='cpu')


Model loaded successfully from /checkpoint/pfz/2024_logs/0911_vseal_pw/_extractor_model=sam_tiny/checkpoint.pth
__log__:{"train_dir": "/datasets01/COCO/060817/train2014/", "train_annotation_file": "/datasets01/COCO/060817/annotations/instances_train2014.json", "val_dir": "/datasets01/COCO/060817/val2014/", "val_annotation_file": "/datasets01/COCO/060817/annotations/instances_val2014.json", "output_dir": "/checkpoint/pfz/2024_logs/0911_vseal_pw/_extractor_model=sam_tiny", "embedder_config": "configs/embedder.yaml", "augmentation_config": "configs/simple_augs.yaml", "extractor_config": "configs/extractor.yaml", "attenuation_config": "configs/attenuation.yaml", "embedder_model": "unet_small2", "extractor_model": "sam_tiny", "nbits": 32, "img_size": 256, "img_size_extractor": 256, "attenuation": "None", "scaling_w": 0.4, "scaling_w_schedule": null, "scaling_i": 1.0, "threshold_mask": 0.6, "optimizer": "AdamW,lr=1e-4", "optimizer_d": null, "scheduler": "CosineLRScheduler,lr_min=1e-6,t_initi

VideoWam(
  (embedder): UnetEmbedder(
    (unet): UNetMsg(
      (msg_processor): MsgProcessor(
        (msg_embeddings): Embedding(64, 64)
      )
      (inc): ResnetBlock(
        (double_conv): Sequential(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): ChanRMSNorm()
          (2): SiLU()
          (3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): ChanRMSNorm()
          (5): SiLU()
        )
        (res_conv): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (downs): ModuleList(
        (0): DBlock(
          (down): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (conv): ResnetBlock(
            (double_conv): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): ChanRMSNorm()
              (2): SiLU()
              (3): Conv2d(32, 32, kernel_size=(3, 3), strid

In [12]:
attenuation_cfg = "configs/attenuation.yaml"
attenuation = "jnd_1_3"
attenuation_cfg = omegaconf.OmegaConf.load(attenuation_cfg)[attenuation]
attenuation = JND(**attenuation_cfg).to(device)
attenuation.preprocess = unnormalize_img
attenuation.postprocess = normalize_img

wam.attenuation = attenuation
wam.scaling_w = 1.0
wam.scaling_i = 1.0

In [13]:
vid = normalize_img(video_tensor / 255)
vid_w = wam.embed_inference(vid)

In [14]:
out_path = "output.mp4"

video_tensor_w = unnormalize_img(vid_w)
video_tensor_w = video_tensor_w.clamp(0, 1)
video_tensor_w = video_tensor_w.numpy()
video_tensor_w = 255 * np.transpose(video_tensor_w, (0, 2, 3, 1))

# save_vid
torchvision.io.write_video(out_path, video_tensor_w, fps=fps, video_codec='libx264', options={'crf': '21'})

# get video fps and durations
cap = cv2.VideoCapture(out_path)
fps = cap.get(cv2.CAP_PROP_FPS)
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
duration = frame_count / fps
cap.release()
cap = cv2.VideoCapture(video_path)
ori_fps = cap.get(cv2.CAP_PROP_FPS)
ori_frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
ori_duration = ori_frame_count / ori_fps
cap.release()
print(f"Output video fps: {fps}, duration: {duration:.2f}s, Original video fps: {ori_fps}, duration: {ori_duration:.2f}s")

# get sizes
size = os.path.getsize(out_path) / 1e6
original_size = os.path.getsize(video_path) / 1e6
size_per_sec = size / duration
original_size_per_sec = original_size / ori_duration
print(f"Output video size: {size:.2f} MB, Original video size: {original_size:.2f} MB")
print(f"Output video size per sec: {size_per_sec:.2f} MB, Original video size per sec: {original_size_per_sec:.2f} MB")

Output video fps: 24.0, duration: 3.00s, Original video fps: 24.0, duration: 13.25s
Output video size: 4.53 MB, Original video size: 16.50 MB
Output video size per sec: 1.51 MB, Original video size per sec: 1.25 MB
