In [2]:
import os
import random
import string
from PIL import Image, ImageDraw, ImageFont
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import torch
import torchvision  # Add this line
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
import shutil
# Debugging information
print(torch.__version__)
print(torchvision.__version__)  # Now this will work
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Current device: {torch.cuda.current_device() if torch.cuda.is_available() else 'CPU'}")
if torch.cuda.is_available():
    print(f"Device name: {torch.cuda.get_device_name(0)}")


2.5.1+cu118
0.20.1+cu118
PyTorch version: 2.5.1+cu118
CUDA available: True
Current device: 0
Device name: NVIDIA GeForce RTX 3050 Laptop GPU


In [81]:
import os
import shutil

# Paths
BASE_DIR = "./dataset"
NO_WATERMARK_DIR = os.path.join(BASE_DIR, "no_watermark")
WATERMARKED_DIR = os.path.join(BASE_DIR, "watermarked")
MASKS_DIR = os.path.join(BASE_DIR, "masks")

OUTPUT_DIR = "./dataset_split"
TRAIN_DIR = os.path.join(OUTPUT_DIR, "train")
TEST_DIR = os.path.join(OUTPUT_DIR, "test")

# Create output directories
os.makedirs(os.path.join(TRAIN_DIR, "no_watermark"), exist_ok=True)
os.makedirs(os.path.join(TRAIN_DIR, "watermarked"), exist_ok=True)
os.makedirs(os.path.join(TRAIN_DIR, "masks"), exist_ok=True)

os.makedirs(os.path.join(TEST_DIR, "watermarked"), exist_ok=True)
os.makedirs(os.path.join(TEST_DIR, "masks"), exist_ok=True)
os.makedirs(os.path.join(TEST_DIR, "no_watermark"), exist_ok=True)

# File names
file_names = sorted(os.listdir(NO_WATERMARK_DIR))
total_files = len(file_names)

# Split sizes
TRAIN_SPLIT = int(0.8 * total_files)  # 80% for training
TEST_SPLIT = total_files - TRAIN_SPLIT  # Remaining 20% for testing

train_files = file_names[:TRAIN_SPLIT]
test_files = file_names[TRAIN_SPLIT:]

# Function to copy files
def copy_files(file_list, src_dir, dest_dir):
    for file_name in file_list:
        src_path = os.path.join(src_dir, file_name)
        dest_path = os.path.join(dest_dir, file_name)
        if os.path.exists(src_path):
            shutil.copy(src_path, dest_path)

# Copy training files
print("Copying training files...")
copy_files(train_files, NO_WATERMARK_DIR, os.path.join(TRAIN_DIR, "no_watermark"))
copy_files(train_files, WATERMARKED_DIR, os.path.join(TRAIN_DIR, "watermarked"))
copy_files(train_files, MASKS_DIR, os.path.join(TRAIN_DIR, "masks"))

# Copy testing files (only watermarked and masks)
print("Copying testing files...")
copy_files(test_files, WATERMARKED_DIR, os.path.join(TEST_DIR, "watermarked"))
copy_files(test_files, MASKS_DIR, os.path.join(TEST_DIR, "masks"))
copy_files(test_files, NO_WATERMARK_DIR, os.path.join(TEST_DIR, "no_watermark"))

# Summary
print(f"Dataset split complete.")
print(f"Training files: {len(train_files)}")
print(f"Testing files: {len(test_files)}")


Copying training files...
Copying testing files...
Dataset split complete.
Training files: 9981
Testing files: 2496


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
from tqdm import tqdm

# Set device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Hyperparameters
EPOCHS = 10
LEARNING_RATE = 0.0001
WINDOW_SIZE = 64
STRIDE = 32

# Dataset paths
TRAIN_DIR = "./dataset_split/train"
TEST_DIR = "./dataset_split/test"
MASK_OUTPUT_DIR = "./test_mask_results"
os.makedirs(MASK_OUTPUT_DIR, exist_ok=True)

# Data transformations (no resizing to retain original sizes)
transform = transforms.Compose([
    transforms.ToTensor()  # Convert to tensor while retaining original size
])

# Custom Dataset
class WatermarkSlidingWindowDataset(Dataset):
    def __init__(self, watermark_dir, mask_dir, transform=None):
        self.watermark_dir = watermark_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(watermark_dir))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        watermark_path = os.path.join(self.watermark_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])

        watermark_img = Image.open(watermark_path).convert("RGB")
        mask_img = Image.open(mask_path).convert("L")  # Mask is single-channel

        original_size = watermark_img.size  # Save original image size (width, height)
        filename = self.images[idx]  # Save the filename for reference

        if self.transform:
            watermark_img = self.transform(watermark_img)
            mask_img = self.transform(mask_img)

        return watermark_img, mask_img, original_size, filename


# Custom collate_fn to handle varying image sizes
def collate_fn(batch):
    return batch

# Load training and testing datasets
train_dataset = WatermarkSlidingWindowDataset(
    watermark_dir=os.path.join(TRAIN_DIR, "watermarked"),
    mask_dir=os.path.join(TRAIN_DIR, "masks"),
    transform=transform
)

test_dataset = WatermarkSlidingWindowDataset(
    watermark_dir=os.path.join(TEST_DIR, "watermarked"),
    mask_dir=os.path.join(TEST_DIR, "masks"),
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Sliding Window Feature Extractor
class SlidingWindowFeatureExtractor(nn.Module):
    def __init__(self, window_size=64, stride=32):
        super(SlidingWindowFeatureExtractor, self).__init__()
        self.window_size = window_size
        self.stride = stride
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),  # Single channel output
            nn.Sigmoid()  # Mask values in range [0, 1]
        )

    def forward(self, x):
        return self.encoder(x)

# Initialize model, optimizer, and loss function
model = SlidingWindowFeatureExtractor(window_size=WINDOW_SIZE, stride=STRIDE).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss

# Training loop
print("Starting training...")
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        watermarked_imgs, masks, original_sizes, filenames = batch[0]
        watermarked_imgs = watermarked_imgs.to(DEVICE).unsqueeze(0)  # Add batch dimension
        masks = masks.to(DEVICE).unsqueeze(0)

        optimizer.zero_grad()
        mask_pred = model(watermarked_imgs)
        # Resize predicted mask to target size
        _, _, target_h, target_w = masks.shape
        mask_pred = torch.nn.functional.interpolate(mask_pred, size=(target_h, target_w), mode='bilinear', align_corners=True)
        loss = criterion(mask_pred, masks)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {epoch_loss / len(train_loader):.4f}")

# Save the model
FEATURE_EXTRACTOR_PATH = "feature_extractor.pth"
torch.save(model.state_dict(), FEATURE_EXTRACTOR_PATH)
print("Feature Extractor saved.")




Using device: cuda
Starting training...


Epoch 1/10: 100%|██████████| 9981/9981 [07:03<00:00, 23.58it/s]


Epoch 1/10, Loss: 0.0736


Epoch 2/10: 100%|██████████| 9981/9981 [07:01<00:00, 23.68it/s]


Epoch 2/10, Loss: 0.0589


Epoch 3/10: 100%|██████████| 9981/9981 [07:00<00:00, 23.72it/s]


Epoch 3/10, Loss: 0.0532


Epoch 4/10: 100%|██████████| 9981/9981 [07:00<00:00, 23.75it/s]


Epoch 4/10, Loss: 0.0501


Epoch 5/10: 100%|██████████| 9981/9981 [07:01<00:00, 23.68it/s]


Epoch 5/10, Loss: 0.0482


Epoch 6/10: 100%|██████████| 9981/9981 [07:01<00:00, 23.67it/s]


Epoch 6/10, Loss: 0.0468


Epoch 7/10: 100%|██████████| 9981/9981 [07:00<00:00, 23.76it/s]


Epoch 7/10, Loss: 0.0458


Epoch 8/10: 100%|██████████| 9981/9981 [06:59<00:00, 23.82it/s]


Epoch 8/10, Loss: 0.0451


Epoch 9/10: 100%|██████████| 9981/9981 [06:59<00:00, 23.78it/s]


Epoch 9/10, Loss: 0.0443


Epoch 10/10: 100%|██████████| 9981/9981 [06:57<00:00, 23.91it/s]

Epoch 10/10, Loss: 0.0437
Feature Extractor saved.





In [6]:
import torch
print(torch.__version__)  # 查看 PyTorch 版本
print(torch.version.cuda)  # 查看 PyTorch 編譯時使用的 CUDA 版本
print(torch.cuda.is_available())  # 應該返回 True


2.5.1+cu118
11.8
True


In [33]:
import numpy
print(numpy.__version__)

2.1.3


In [97]:
def test_feature_extractor(test_loader, model_path):
    print("Testing feature extractor...")
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader)):
            # Extract watermarked image, mask, original size, and filename
            watermarked_imgs, masks, original_sizes, filenames = batch[0]

            # Debug: Print the filename being processed
            # print(f"Processing image: {filenames}")

            watermarked_imgs = watermarked_imgs.to(DEVICE).unsqueeze(0)
            mask_pred = model(watermarked_imgs)

            # Resize predicted mask to original size
            original_size = original_sizes  # (width, height)
            filename = filenames

            # Extract base name and ensure valid file extension
            base_name, ext = os.path.splitext(filename)
            if ext.lower() not in [".jpg", ".jpeg", ".png"]:  # Handle invalid extensions
                ext = ".png"  # Default to PNG if extension is missing or invalid

            mask_pred_img = transforms.ToPILImage()(mask_pred.squeeze(0).cpu())
            mask_pred_resized = mask_pred_img.resize(original_size, Image.BILINEAR)

            # Save predicted mask
            output_file = os.path.join(MASK_OUTPUT_DIR, f"predicted_mask_{base_name}{ext}")
            mask_pred_resized.save(output_file)

            # print(f"Saved mask: {output_file}")

    print("All masks generated and saved.")

# Test the feature extractor
test_feature_extractor(test_loader, FEATURE_EXTRACTOR_PATH)


Testing feature extractor...


  model.load_state_dict(torch.load(model_path, map_location=DEVICE))


RuntimeError: Error(s) in loading state_dict for WatermarkRemovalModel:
	Missing key(s) in state_dict: "decoder.0.weight", "decoder.0.bias", "decoder.2.weight", "decoder.2.bias". 
	Unexpected key(s) in state_dict: "encoder.5.weight", "encoder.5.bias", "encoder.7.weight", "encoder.7.bias", "encoder.9.weight", "encoder.9.bias". 
	size mismatch for encoder.0.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3]).

In [64]:
import torch
import os
import numpy as np
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import cv2

# Load the trained model
model = SlidingWindowFeatureExtractor(window_size=64, stride=32).to(DEVICE)
model.load_state_dict(torch.load("feature_extractor.pth", map_location=DEVICE))
model.eval()

def generate_clean_images(data_loader, model, output_dir="./clean_images"):
    """
    Generate de-watermarked images using the trained model and save them.
    :param data_loader: DataLoader that loads the data
    :param model: Trained Feature Extractor model
    :param output_dir: Path to output de-watermarked images
    """
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            # Unpack data
            watermarked_imgs, masks, original_sizes, filenames = batch[0]
            watermarked_imgs = watermarked_imgs.to(DEVICE).unsqueeze(0)

            # Use the model to generate mask
            predicted_mask = model(watermarked_imgs)

            # Resize mask to match original image
            original_size = original_sizes
            predicted_mask_resized = torch.nn.functional.interpolate(
                predicted_mask, size=original_size[::-1], mode="bilinear", align_corners=True
            )

            # Convert tensors to NumPy arrays
            watermarked_img_np = watermarked_imgs.squeeze(0).cpu().numpy()
            predicted_mask_np = predicted_mask_resized.squeeze(0).cpu().numpy()

            # Transpose the image from (C, H, W) to (H, W, C)
            watermarked_img_np = np.transpose(watermarked_img_np, (1, 2, 0))

            # Convert predicted mask to binary mask (uint8)
            mask_binary = (predicted_mask_np.squeeze(0) > 0.5).astype(np.uint8) * 255

            # Convert image to uint8 format
            watermarked_img_uint8 = (watermarked_img_np * 255).astype(np.uint8)

            # Use OpenCV inpainting
            inpainted_img = cv2.inpaint(watermarked_img_uint8, mask_binary, 3, cv2.INPAINT_TELEA)

            # Convert back to PIL image
            inpainted_img_pil = Image.fromarray(inpainted_img)

            output_path = os.path.join(output_dir, filenames)
            inpainted_img_pil.save(output_path)

            print(f"Saved clean image: {output_path}")

# Rest of your code remains the same
train_loader = DataLoader(
    WatermarkSlidingWindowDataset(
        watermark_dir=os.path.join(TRAIN_DIR, "watermarked"),
        mask_dir=os.path.join(TRAIN_DIR, "masks"),
        transform=transform
    ),
    batch_size=1, shuffle=False, collate_fn=collate_fn
)

# Define output directory
CLEAN_IMAGE_OUTPUT_DIR = "./training_clean_images"

# Call the inference function to generate de-watermarked images
generate_clean_images(train_loader, model, output_dir=CLEAN_IMAGE_OUTPUT_DIR)


  model.load_state_dict(torch.load("feature_extractor.pth", map_location=DEVICE))


Saved clean image: ./training_clean_images\1.jpg
Saved clean image: ./training_clean_images\10.jpg
Saved clean image: ./training_clean_images\100.jpeg
Saved clean image: ./training_clean_images\1000.jpeg
Saved clean image: ./training_clean_images\10000.jpeg
Saved clean image: ./training_clean_images\10001.jpeg
Saved clean image: ./training_clean_images\10002.jpeg
Saved clean image: ./training_clean_images\10003.jpeg
Saved clean image: ./training_clean_images\10004.jpeg
Saved clean image: ./training_clean_images\10005.jpeg
Saved clean image: ./training_clean_images\10006.jpeg
Saved clean image: ./training_clean_images\10007.jpeg
Saved clean image: ./training_clean_images\10008.jpeg
Saved clean image: ./training_clean_images\10009.jpeg
Saved clean image: ./training_clean_images\1001.jpeg
Saved clean image: ./training_clean_images\10010.jpeg
Saved clean image: ./training_clean_images\10011.jpeg
Saved clean image: ./training_clean_images\10012.jpeg
Saved clean image: ./training_clean_image

In [72]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

# 数据转换
transform = transforms.Compose([
    transforms.ToTensor()
])

# 自定义数据集
class WatermarkRemovalDataset(Dataset):
    def __init__(self, watermarked_dir, mask_dir, clean_dir, transform=None):
        self.watermarked_dir = watermarked_dir
        self.mask_dir = mask_dir
        self.clean_dir = clean_dir
        self.images = sorted(os.listdir(watermarked_dir))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        watermarked_path = os.path.join(self.watermarked_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])
        clean_path = os.path.join(self.clean_dir, self.images[idx])

        watermarked_img = Image.open(watermarked_path).convert("RGB")
        mask_img = Image.open(mask_path).convert("L")  # Mask is单通道
        clean_img = Image.open(clean_path).convert("RGB")

        if self.transform:
            watermarked_img = self.transform(watermarked_img)
            mask_img = self.transform(mask_img)
            clean_img = self.transform(clean_img)

        return watermarked_img, mask_img, clean_img

# 自定义 collate_fn
def custom_collate_fn(batch):
    watermarked_imgs, masks, clean_imgs = zip(*batch)
    return list(watermarked_imgs), list(masks), list(clean_imgs)

# 数据加载器
train_dataset = WatermarkRemovalDataset(
    watermarked_dir="./dataset_split/train/watermarked",
    mask_dir="./dataset_split/train/masks",
    clean_dir="./dataset_split/train/no_watermark",
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_collate_fn
)


In [73]:
import torch.nn as nn

class WatermarkRemovalModel(nn.Module):
    def __init__(self):
        super(WatermarkRemovalModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # 输出值归一化到 [0, 1]
        )

    def forward(self, watermarked_img, mask):
        x = torch.cat((watermarked_img, mask), dim=1)  # 将图像和mask在通道维度拼接
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [74]:
from tqdm import tqdm
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 10
LEARNING_RATE = 0.001

# 初始化模型、损失函数和优化器
model = WatermarkRemovalModel().to(DEVICE)
criterion = nn.MSELoss()  # Mean Squared Error Loss
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Starting training...")
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    # 使用 tqdm 包装数据加载器以显示进度条
    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS}") as progress_bar:
        for watermarked_imgs, masks, clean_imgs in progress_bar:
            batch_loss = 0  # 记录当前 batch 的损失
            for watermarked_img, mask, clean_img in zip(watermarked_imgs, masks, clean_imgs):
                watermarked_img = watermarked_img.to(DEVICE).unsqueeze(0)  # 添加 batch 维度
                mask = mask.to(DEVICE).unsqueeze(0)
                clean_img = clean_img.to(DEVICE).unsqueeze(0)

                optimizer.zero_grad()
                output = model(watermarked_img, mask)

                # 动态调整输出尺寸匹配 clean_img
                _, _, target_h, target_w = clean_img.shape
                output = F.interpolate(output, size=(target_h, target_w), mode='bilinear', align_corners=True)

                loss = criterion(output, clean_img)
                loss.backward()
                optimizer.step()

                batch_loss += loss.item()
                epoch_loss += loss.item()

            # 更新进度条显示当前 batch 的平均损失
            progress_bar.set_postfix({"Batch Loss": batch_loss / len(watermarked_imgs)})

    # 打印每个 epoch 的平均损失
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {epoch_loss / len(train_loader):.4f}")

# 保存模型
torch.save(model.state_dict(), "watermark_removal_model.pth")
print("Model saved.")


Starting training...


Epoch 1/10: 100%|██████████| 9981/9981 [12:55<00:00, 12.88it/s, Batch Loss=0.000672]


Epoch 1/10, Loss: 0.0018


Epoch 2/10: 100%|██████████| 9981/9981 [12:54<00:00, 12.88it/s, Batch Loss=0.000608]


Epoch 2/10, Loss: 0.0009


Epoch 3/10: 100%|██████████| 9981/9981 [12:54<00:00, 12.89it/s, Batch Loss=0.000203]


Epoch 3/10, Loss: 0.0008


Epoch 4/10: 100%|██████████| 9981/9981 [12:53<00:00, 12.91it/s, Batch Loss=0.00127] 


Epoch 4/10, Loss: 0.0008


Epoch 5/10: 100%|██████████| 9981/9981 [12:56<00:00, 12.85it/s, Batch Loss=0.00239] 


Epoch 5/10, Loss: 0.0008


Epoch 6/10: 100%|██████████| 9981/9981 [12:56<00:00, 12.85it/s, Batch Loss=0.000878]


Epoch 6/10, Loss: 0.0007


Epoch 7/10: 100%|██████████| 9981/9981 [12:56<00:00, 12.86it/s, Batch Loss=0.000686]


Epoch 7/10, Loss: 0.0007


Epoch 8/10: 100%|██████████| 9981/9981 [12:57<00:00, 12.83it/s, Batch Loss=0.000282]


Epoch 8/10, Loss: 0.0007


Epoch 9/10: 100%|██████████| 9981/9981 [12:53<00:00, 12.90it/s, Batch Loss=0.00191] 


Epoch 9/10, Loss: 0.0007


Epoch 10/10: 100%|██████████| 9981/9981 [12:57<00:00, 12.83it/s, Batch Loss=0.0011]  

Epoch 10/10, Loss: 0.0007
Model saved.





In [98]:
from torchvision.transforms.functional import to_pil_image

def remove_watermark(test_loader, model_path, output_dir="./output"):
    os.makedirs(output_dir, exist_ok=True)
    model = WatermarkRemovalModel().to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    with torch.no_grad():
        for i, (watermarked_imgs, masks, _) in enumerate(test_loader):
            for watermarked_img, mask in zip(watermarked_imgs, masks):
                watermarked_img = watermarked_img.to(DEVICE).unsqueeze(0)
                mask = mask.to(DEVICE).unsqueeze(0)

                output = model(watermarked_img, mask)

                # 动态调整输出尺寸以匹配原始图像
                original_size = watermarked_img.shape[2:]  # (height, width)
                output = torch.nn.functional.interpolate(output, size=original_size, mode='bilinear', align_corners=True)

                output = output.squeeze(0).cpu()
                output_img = to_pil_image(output)

                output_path = os.path.join(output_dir, f"result_{i + 1}.png")
                output_img.save(output_path)

    print("Watermark removal complete.")


In [None]:
test_dataset = WatermarkRemovalDataset(
    watermarked_dir="./dataset_split/test/watermarked",
    mask_dir="./dataset_split/test/masks",
    clean_dir="./dataset_split/test/no_watermark",
    transform=transform
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn
)

remove_watermark(test_loader, "watermark_removal_model.pth")


  model.load_state_dict(torch.load(model_path, map_location=DEVICE))


KeyboardInterrupt: 