In [1]:
import argparse

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from PIL import Image
from torchvision import transforms as T


preprocessing = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def parse_output(water_sym, clear_sym):
    return f"{'watermark' if water_sym > clear_sym else 'clear'}\n{water_sym:.3f}%w {clear_sym:.3f}%c"


if __name__ == '__main__':
    model = timm.create_model(
        'efficientnet_b3a', pretrained=True, num_classes=2)

    model.classifier = nn.Sequential(
        # 1536 is the orginal in_features
        nn.Linear(in_features=1536, out_features=625),
        nn.ReLU(),  # ReLu to be the activation function
        nn.Dropout(p=0.3),
        nn.Linear(in_features=625, out_features=256),
        nn.ReLU(),
        nn.Linear(in_features=256, out_features=2),
    )

    state_dict = torch.load('models/watermark_model_v1.pt')

    model.load_state_dict(state_dict)
    model.eval()

    if torch.cuda.is_available():
        model.cuda()
    
    watermark_im = preprocessing(Image.open('./images/watermark_example.png').convert('RGB'))
    clear_im = preprocessing(Image.open('./images/clear_example.png').convert('RGB'))

    batch = torch.stack([watermark_im, clear_im])
    batch=batch.cuda()

    with torch.no_grad():
        pred = model(batch)
        syms = F.softmax(pred, dim=1).detach().cpu().numpy().tolist()
        print(syms)
        for sym in syms:
            water_sym, clear_sym = sym
            if water_sym > clear_sym:
                print()
                # watermark
                pass
            else:
                # clear
                pass
            parse_output(water_sym, clear_sym)


  from .autonotebook import tqdm as notebook_tqdm
  model = create_fn(
  state_dict = torch.load('models/watermark_model_v1.pt')


[[0.9355467557907104, 0.06445320695638657], [0.20668235421180725, 0.7933176159858704]]



In [None]:
import torch
from PIL import Image
from pathlib import Path
from img_filter import ImgFilterArgs, build_filter
# Assuming the previous code is already imported and available

# Define the images and their corresponding prompts
to_test_images = [
    (
        "/jfs/jinjie/code/Pollux/apps/preprocessing/LAION-5B-WatermarkDetection/images/clear_example.png",
        "a football player",
    ),
    (
        "/jfs/jinjie/code/Pollux/apps/preprocessing/LAION-5B-WatermarkDetection/images/watermark_example.png",
        "a family walking on the street",
    ),
]

watermark_filter_args = ImgFilterArgs(
    model_name="WaterMarkFilter",
    pretrained_model_name_or_path="/jfs/checkpoints/data_preprocessing/watermark_model_v1.pt",
)

# Define the clip filter args with the model path
clip_filter_args = ImgFilterArgs(
    model_name="CLIPFilter",
    pretrained_model_name_or_path="openai/clip-vit-base-patch16",
)


# Function to load the image
def load_image(image_path: str) -> Image.Image:
    return Image.open(image_path).convert("RGB")


# Test the filters
def test_filters():
    for image_path, prompt in to_test_images:
        print(f"Testing image: {image_path}")

        # Load the image
        image = load_image(image_path)

        # Test WaterMarkFilter
        print("\nTesting WaterMarkFilter...")
        watermark_args = ImgFilterArgs(model_name="WaterMarkFilter")
        watermark_filter = build_filter(watermark_args)
        watermark_score = watermark_filter.predict(image)
        print(f"Watermark score (clear probability): {watermark_score:.4f}")

        # Test CLIPFilter
        print("\nTesting CLIPFilter...")
        clip_args = ImgFilterArgs(model_name="CLIPFilter")
        clip_filter = build_filter(clip_args)
        clip_score = clip_filter.predict(image, prompt=prompt)
        print(f"CLIP score: {clip_score:.4f}")


# Run the test
test_filters()