In [1]:
import os
import random
import string
from PIL import Image, ImageDraw, ImageFont
# os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import torch
import torchvision  # Add this line
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
import shutil
from torchvision.transforms.functional import to_pil_image

In [2]:
# 数据转换
transform = transforms.Compose([
    transforms.ToTensor()
])

# 自定义数据集
class WatermarkRemovalDataset(Dataset):
    def __init__(self, watermarked_dir, mask_dir, clean_dir, transform=None):
        self.watermarked_dir = watermarked_dir
        self.mask_dir = mask_dir
        self.clean_dir = clean_dir
        self.images = sorted(os.listdir(watermarked_dir))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        watermarked_path = os.path.join(self.watermarked_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx])
        clean_path = os.path.join(self.clean_dir, self.images[idx])

        watermarked_img = Image.open(watermarked_path).convert("RGB")
        mask_img = Image.open(mask_path).convert("L")  # Mask is 单通道
        clean_img = Image.open(clean_path).convert("RGB")

        if self.transform:
            watermarked_img = self.transform(watermarked_img)
            mask_img = self.transform(mask_img)
            clean_img = self.transform(clean_img)

        # 返回四个值：watermarked_img, mask_img, clean_img, filename
        return watermarked_img, mask_img, clean_img, self.images[idx]

# 自定义 collate_fn
def custom_collate_fn(batch):
    watermarked_imgs, masks, clean_imgs, filenames = zip(*batch)
    return list(watermarked_imgs), list(masks), list(clean_imgs), list(filenames)

# 数据加载器
train_dataset = WatermarkRemovalDataset(
    watermarked_dir="./web_dataset_split/test1/watermarked",
    mask_dir="./web_test_mask_results",
    clean_dir="./web_dataset_split/test1/no_watermark",
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_collate_fn
)


In [3]:
import torch.nn as nn

class WatermarkRemovalModel(nn.Module):
    def __init__(self):
        super(WatermarkRemovalModel, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # 输出值归一化到 [0, 1]
        )

    def forward(self, watermarked_img, mask):
        x = torch.cat((watermarked_img, mask), dim=1)  # 将图像和mask在通道维度拼接
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [4]:
from tqdm import tqdm
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 10
LEARNING_RATE = 0.001

# 初始化模型、损失函数和优化器
model = WatermarkRemovalModel().to(DEVICE)
criterion = nn.MSELoss()  # Mean Squared Error Loss
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Starting training...")
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    # 使用 tqdm 包装数据加载器以显示进度条
    with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS}") as progress_bar:
        for watermarked_imgs, masks, clean_imgs, file_names in progress_bar:
            batch_loss = 0  # 记录当前 batch 的损失
            for watermarked_img, mask, clean_img, file_name in zip(watermarked_imgs, masks, clean_imgs, file_names):
                watermarked_img = watermarked_img.to(DEVICE).unsqueeze(0)  # 添加 batch 维度
                mask = mask.to(DEVICE).unsqueeze(0)
                clean_img = clean_img.to(DEVICE).unsqueeze(0)

                optimizer.zero_grad()
                output = model(watermarked_img, mask)

                # 动态调整输出尺寸匹配 clean_img
                _, _, target_h, target_w = clean_img.shape
                output = F.interpolate(output, size=(target_h, target_w), mode='bilinear', align_corners=True)

                loss = criterion(output, clean_img)
                loss.backward()
                optimizer.step()

                batch_loss += loss.item()
                epoch_loss += loss.item()

            # 更新进度条显示当前 batch 的平均损失
            progress_bar.set_postfix({"Batch Loss": batch_loss / len(watermarked_imgs)})

    # 打印每个 epoch 的平均损失
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {epoch_loss / len(train_loader):.4f}")

# 保存模型
torch.save(model.state_dict(), "web_watermark_removal_model.pth")
print("Model saved.")


Starting training...


Epoch 1/10: 100%|██████████| 4000/4000 [01:27<00:00, 45.80it/s, Batch Loss=0.00172] 


Epoch 1/10, Loss: 0.0035


Epoch 2/10: 100%|██████████| 4000/4000 [00:52<00:00, 76.64it/s, Batch Loss=0.00157] 


Epoch 2/10, Loss: 0.0017


Epoch 3/10: 100%|██████████| 4000/4000 [00:52<00:00, 76.74it/s, Batch Loss=0.00132] 


Epoch 3/10, Loss: 0.0015


Epoch 4/10: 100%|██████████| 4000/4000 [00:52<00:00, 76.25it/s, Batch Loss=0.000674]


Epoch 4/10, Loss: 0.0014


Epoch 5/10: 100%|██████████| 4000/4000 [00:54<00:00, 72.84it/s, Batch Loss=0.00015] 


Epoch 5/10, Loss: 0.0013


Epoch 6/10: 100%|██████████| 4000/4000 [00:54<00:00, 73.62it/s, Batch Loss=0.000711]


Epoch 6/10, Loss: 0.0013


Epoch 7/10: 100%|██████████| 4000/4000 [00:53<00:00, 74.54it/s, Batch Loss=0.000571]


Epoch 7/10, Loss: 0.0013


Epoch 8/10: 100%|██████████| 4000/4000 [00:53<00:00, 74.93it/s, Batch Loss=0.000463]


Epoch 8/10, Loss: 0.0012


Epoch 9/10: 100%|██████████| 4000/4000 [00:53<00:00, 75.15it/s, Batch Loss=0.00451] 


Epoch 9/10, Loss: 0.0012


Epoch 10/10: 100%|██████████| 4000/4000 [00:53<00:00, 74.89it/s, Batch Loss=0.00138] 

Epoch 10/10, Loss: 0.0012
Model saved.





In [6]:
from torchvision.transforms.functional import to_pil_image
import os

def remove_watermark(test_loader, model_path, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    model = WatermarkRemovalModel().to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    with torch.no_grad():
        for watermarked_imgs, masks, clean_imgs ,filenames in test_loader:  # filenames 从数据集中获取
            for watermarked_img, mask, clean_img ,filename in zip(watermarked_imgs, masks, clean_imgs, filenames):
                watermarked_img = watermarked_img.to(DEVICE).unsqueeze(0)
                mask = mask.to(DEVICE).unsqueeze(0)

                output = model(watermarked_img, mask)

                # 动态调整输出尺寸
                original_size = watermarked_img.shape[2:]
                output = torch.nn.functional.interpolate(output, size=original_size, mode='bilinear', align_corners=True)

                output = output.squeeze(0).cpu()
                output_img = to_pil_image(output)

                # 保存文件，确保文件名和扩展名正确
                base_name, ext = os.path.splitext(filename)
                if ext.lower() not in [".jpg", ".jpeg", ".png"]:
                    ext = ".png"  # 默认扩展名
                output_path = os.path.join(output_dir, f"{base_name}{ext}")
                output_img.save(output_path)

                print(f"Saved: {output_path}")

    print("Watermark removal complete.")


In [7]:
test_dataset = WatermarkRemovalDataset(
    watermarked_dir="./web_dataset_split/test2/watermarked",
    mask_dir="./web_test_model_result/mask",
    clean_dir="./web_dataset_split/test2/no_watermark",
    transform=transform
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn
)

output_dir="./web_test_model_result/watermark_removal_output"

remove_watermark(test_loader, "web_watermark_removal_model.pth", output_dir)


  model.load_state_dict(torch.load(model_path, map_location=DEVICE))


Saved: ./web_test_model_result/watermark_removal_output\777.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7770.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7771.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7772.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7773.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7774.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7775.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7776.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7777.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7778.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7779.jpeg
Saved: ./web_test_model_result/watermark_removal_output\778.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7780.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7781.jpeg
Saved: ./web_test_model_result/watermark_removal_output\7782.jpeg
Saved: ./web

In [8]:
test_dataset = WatermarkRemovalDataset(
    watermarked_dir="./true_web_data/watermarked",
    mask_dir="./true_web_data/model2/predicted_mask",
    clean_dir="./true_web_data/no_watermark",
    transform=transform
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=custom_collate_fn
)

output_dir="./true_web_data/model2/output_no_watermark"

remove_watermark(test_loader, "web_watermark_removal_model.pth", output_dir)



  model.load_state_dict(torch.load(model_path, map_location=DEVICE))


Saved: ./true_web_data/model2/output_no_watermark\00A5QK.jpg
Saved: ./true_web_data/model2/output_no_watermark\00DT57.jpg
Saved: ./true_web_data/model2/output_no_watermark\01FAOU.jpg
Saved: ./true_web_data/model2/output_no_watermark\01IMVE.jpg
Saved: ./true_web_data/model2/output_no_watermark\01UZBZ.jpg
Saved: ./true_web_data/model2/output_no_watermark\02Z6TT.jpg
Saved: ./true_web_data/model2/output_no_watermark\03UINF.jpg
Saved: ./true_web_data/model2/output_no_watermark\03YTHF.jpg
Saved: ./true_web_data/model2/output_no_watermark\0AH3BR.jpg
Saved: ./true_web_data/model2/output_no_watermark\0B9T85.jpg
Saved: ./true_web_data/model2/output_no_watermark\0BJW9C.jpg
Saved: ./true_web_data/model2/output_no_watermark\0BWU6O.jpg
Saved: ./true_web_data/model2/output_no_watermark\0F6BBP.jpg
Saved: ./true_web_data/model2/output_no_watermark\0GAHD1.jpg
Saved: ./true_web_data/model2/output_no_watermark\0GP2FA.jpg
Saved: ./true_web_data/model2/output_no_watermark\0GR7CG.jpg
Saved: ./true_web_data/m