In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
class Generator(nn.Module):
    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, target_size=None, normalize=True):
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if target_size is not None:
                layers.append(nn.Upsample(size=target_size, mode='bilinear', align_corners=False))
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.ReLU())
            return layers
        
        # 修改 self.model 的定义，在最后添加一层上采样以确保输出尺寸正确
        self.model = nn.Sequential(
            *downsample(channels, 64, normalize=False),
            *downsample(64, 64),
            *downsample(64, 128),
            *downsample(128, 256),
            *downsample(256, 512),
            nn.Conv2d(512, 4000, 1),
            *upsample(4000, 512),
            *upsample(512, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            *upsample(64, channels, target_size=(128, 128)),  # 确保输出尺寸为 128x128
            nn.Tanh()
        )

    def forward(self, x, mask):
        # x 是原始输入图像（包含完整的图像信息）
        # mask 是掩码，指示哪些区域需要被修改
        generated_content = self.model(x)

        # 只更新掩码区域
        output = x * (1 - mask) + generated_content * mask
        return output

In [3]:
class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = channels
        for out_filters, stride, normalize in [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [4]:
import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(self, root, category, transforms_=None, img_size=128, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.img_size = img_size
        self.mode = mode
        self.category = category
        # 构建原始图像和掩码的路径
        original_img_path = os.path.join(root, 'original_image', category)
        mask_path = os.path.join(root, 'mask', category)

        self.original_files = sorted(glob.glob(f"{original_img_path}/*.jpg"))
        self.mask_files = sorted(glob.glob(f"{mask_path}/*.jpg"))  
        self.transforms_mask = transforms.Compose([
            transforms.Resize((128, 128)),  # 确保掩码大小一致
            transforms.ToTensor(),  # 转为张量
        ])

        assert len(self.original_files) == len(self.mask_files), "The number of images and masks do not match!"
        

    # def apply_random_mask(self, img):
    #     """Randomly masks image"""
    #     y1, x1 = np.random.randint(0, self.img_size - self.mask_size, 2)
    #     y2, x2 = y1 + self.mask_size, x1 + self.mask_size
    #     masked_part = img[:, y1:y2, x1:x2]
    #     masked_img = img.clone()
    #     masked_img[:, y1:y2, x1:x2] = 1

    #     return masked_img, masked_part

    # def apply_center_mask(self, img):
    #     """Mask center part of image"""
    #     # Get upper-left pixel coordinate
    #     i = (self.img_size - self.mask_size) // 2
    #     masked_img = img.clone()
    #     masked_img[:, i : i + self.mask_size, i : i + self.mask_size] = 1

    #     return masked_img, i
    def apply_mask(self, original_img, mask_img):
        # 确保掩码是单通道且大小匹配原始图像
        if mask_img.shape[0] == 1:  # 如果是单通道
            mask_img = mask_img.expand(3, -1, -1)  # 扩展为 3 通道
        
        # 提取被遮挡部分的图像
        masked_part = original_img * mask_img
        
        # 创建被遮挡的图像，用1填充遮挡区域
        masked_img = original_img.clone()
        masked_img[mask_img == 1] = 1  # 用1替换遮挡区域
        
        if mask_img.shape[0] == 1:  # 如果是单通道
            mask_img = mask_img.expand(3, -1, -1)  # 扩展到 3 通道

        return masked_img, masked_part
    
    def __getitem__(self, index):
        # 处理原始图像
        img = Image.open(self.original_files[index]).convert("RGB")
        img = self.transform(img)
        
        # 处理掩码图像
        mask_img = Image.open(self.mask_files[index]).convert("L")  # 转为灰度图
        mask_img = self.transforms_mask(mask_img)
        
        # 确保掩码是二值化的
        mask_img = (mask_img > 0.5).float()  # 掩码值为 0 或 1
    
        return img, mask_img


    def __len__(self):
        return len(self.original_files)

In [5]:
class Opt:
    def __init__(self):
        self.n_epochs = 200
        self.batch_size = 8
        self.dataset_name = "img_align_celeba"
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 4
        self.latent_dim = 100
        self.img_size = 128
        self.mask_size = 64
        self.channels = 3
        self.sample_interval = 500

    def __repr__(self):
        return (f"Opt(n_epochs={self.n_epochs}, batch_size={self.batch_size}, "
                f"dataset_name='{self.dataset_name}', lr={self.lr}, b1={self.b1}, "
                f"b2={self.b2}, n_cpu={self.n_cpu}, latent_dim={self.latent_dim}, "
                f"img_size={self.img_size}, mask_size={self.mask_size}, "
                f"channels={self.channels}, sample_interval={self.sample_interval})")

In [6]:
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable


import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

opt = Opt()


cuda = True if torch.cuda.is_available() else False

# Calculate output of image discriminator (PatchGAN)
patch_h, patch_w = 16, 16  # 固定为16，以匹配 discriminator 的输出
patch = (1, patch_h, patch_w)


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# Loss function
adversarial_loss = torch.nn.MSELoss()
pixelwise_loss = torch.nn.L1Loss()

# Initialize generator and discriminator
generator = Generator(channels=opt.channels)
discriminator = Discriminator(channels=opt.channels)

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    pixelwise_loss.cuda()

# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Dataset loader
transforms_ = [
    transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]



# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor



# ----------
#  Training
# ----------
def train_model(category):
    dataloader = DataLoader(
    ImageDataset("./dataset_train", category, transforms_=transforms_),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
    )
    test_dataloader = DataLoader(
        ImageDataset("./dataset_test", category, transforms_=transforms_, mode="val"),
        batch_size=12,
        shuffle=True,
        num_workers=1,
    )

    for epoch in range(opt.n_epochs):
        for i, (imgs, masks) in enumerate(dataloader):
    
            # Adversarial ground truths
            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1, patch_h, patch_w).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1, patch_h, patch_w).fill_(0.0), requires_grad=False)
            
            # Debugging: 打印出 valid 和 fake 的形状
    
             # Configure input
            imgs = Variable(imgs.type(Tensor))
            masks = Variable(masks.type(Tensor))
        
            # Create masked images (input for generator)
            masked_imgs = imgs * (1 - masks) + masks  # 用掩码生成损坏的图像
            # -----------------
            #  Train Generator
            # -----------------
    
            optimizer_G.zero_grad()
    
            # Generate a batch of images
            gen_imgs = generator(masked_imgs, masks)
    
            
    
            # Adversarial and pixelwise loss
            g_adv = adversarial_loss(discriminator(gen_imgs), valid)
            g_pixel = pixelwise_loss(gen_imgs * masks, imgs * masks)  # 只计算掩码区域的 pixel loss
            g_loss = 0.001 * g_adv + 0.999 * g_pixel
    
            g_loss.backward()
            optimizer_G.step()
    
            # ---------------------
            #  Train Discriminator
            # ---------------------
    
            optimizer_D.zero_grad()
    
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(imgs * masks), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach() * masks), fake)
            d_loss = 0.5 * (real_loss + fake_loss)
        
            d_loss.backward()
            optimizer_D.step()
        
            # print(
            #     "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G adv: %f, pixel: %f]"
            #     % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_adv.item(), g_pixel.item())
            # )
        
            # Generate sample at sample interval
            batches_done = epoch * len(dataloader) + i
    # 每个 epoch 结束后保存模型
    model_path = f"{category}_generator.pth"
    torch.save(generator.state_dict(), model_path)
    print(f"Saved generator model to {model_path}")

    model_path = f"{category}_discriminator.pth"
    torch.save(discriminator.state_dict(), model_path)
    print(f"Saved discriminator model to {model_path}")


In [None]:
# 依次训练每种类型的模型
image_types = ["face", "scenario", "street_scene_pairs", "texture"]
for image_type in image_types:
    #train_model(image_type=image_type, num_epochs=10, batch_size=16, learning_rate=1e-5)
    train_model(image_type)

  valid = Variable(Tensor(imgs.shape[0], 1, patch_h, patch_w).fill_(1.0), requires_grad=False)


Saved generator model to face_generator.pth
Saved discriminator model to face_discriminator.pth
Saved generator model to scenario_generator.pth
Saved discriminator model to scenario_discriminator.pth
Saved generator model to street_scene_pairs_generator.pth
Saved discriminator model to street_scene_pairs_discriminator.pth
