In [1]:
import torch
import glob
import random
import os
import numpy as np
import torch.nn as nn

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

In [2]:
class Generator(nn.Module):
    def __init__(self, channels=3):
        super(Generator, self).__init__()

        def downsample(in_feat, out_feat, normalize=True):
            layers = [nn.Conv2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2))
            return layers

        def upsample(in_feat, out_feat, target_size=None, normalize=True):
            layers = [nn.ConvTranspose2d(in_feat, out_feat, 4, stride=2, padding=1)]
            if target_size is not None:
                layers.append(nn.Upsample(size=target_size, mode='bilinear', align_corners=False))
            if normalize:
                layers.append(nn.BatchNorm2d(out_feat, 0.8))
            layers.append(nn.ReLU())
            return layers
        
        # 修改 self.model 的定义，在最后添加一层上采样以确保输出尺寸正确
        self.model = nn.Sequential(
            *downsample(channels, 64, normalize=False),
            *downsample(64, 64),
            *downsample(64, 128),
            *downsample(128, 256),
            *downsample(256, 512),
            nn.Conv2d(512, 4000, 1),
            *upsample(4000, 512),
            *upsample(512, 256),
            *upsample(256, 128),
            *upsample(128, 64),
            *upsample(64, channels, target_size=(128, 128)),  # 确保输出尺寸为 128x128
            nn.Tanh()
        )

    def forward(self, x, mask):
        # x 是原始输入图像（包含完整的图像信息）
        # mask 是掩码，指示哪些区域需要被修改
        generated_content = self.model(x)

        # 只更新掩码区域
        output = x * (1 - mask) + generated_content * mask
        return output

In [3]:
class Discriminator(nn.Module):
    def __init__(self, channels=3):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Returns layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = channels
        for out_filters, stride, normalize in [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]:
            layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, 3, 1, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [4]:
import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(self, root, category, transforms_=None, img_size=128, mode="train"):
        self.transform = transforms.Compose(transforms_)
        self.img_size = img_size
        self.mode = mode
        self.category = category
        # 构建原始图像和掩码的路径
        original_img_path = os.path.join(root, 'original_image', category)
        mask_path = os.path.join(root, 'mask', category)

        self.original_files = sorted(glob.glob(f"{original_img_path}/*.jpg"))
        self.mask_files = sorted(glob.glob(f"{mask_path}/*.jpg"))  
        self.transforms_mask = transforms.Compose([
            transforms.Resize((128, 128)),  # 确保掩码大小一致
            transforms.ToTensor(),  # 转为张量
        ])

        assert len(self.original_files) == len(self.mask_files), "The number of images and masks do not match!"
        

    # def apply_random_mask(self, img):
    #     """Randomly masks image"""
    #     y1, x1 = np.random.randint(0, self.img_size - self.mask_size, 2)
    #     y2, x2 = y1 + self.mask_size, x1 + self.mask_size
    #     masked_part = img[:, y1:y2, x1:x2]
    #     masked_img = img.clone()
    #     masked_img[:, y1:y2, x1:x2] = 1

    #     return masked_img, masked_part

    # def apply_center_mask(self, img):
    #     """Mask center part of image"""
    #     # Get upper-left pixel coordinate
    #     i = (self.img_size - self.mask_size) // 2
    #     masked_img = img.clone()
    #     masked_img[:, i : i + self.mask_size, i : i + self.mask_size] = 1

    #     return masked_img, i
    def apply_mask(self, original_img, mask_img):
        # 确保掩码是单通道且大小匹配原始图像
        if mask_img.shape[0] == 1:  # 如果是单通道
            mask_img = mask_img.expand(3, -1, -1)  # 扩展为 3 通道
        
        # 提取被遮挡部分的图像
        masked_part = original_img * mask_img
        
        # 创建被遮挡的图像，用1填充遮挡区域
        masked_img = original_img.clone()
        masked_img[mask_img == 1] = 1  # 用1替换遮挡区域
        
        if mask_img.shape[0] == 1:  # 如果是单通道
            mask_img = mask_img.expand(3, -1, -1)  # 扩展到 3 通道

        return masked_img, masked_part
    
    def __getitem__(self, index):
        # 处理原始图像
        img = Image.open(self.original_files[index]).convert("RGB")
        img = self.transform(img)
        
        # 处理掩码图像
        mask_img = Image.open(self.mask_files[index]).convert("L")  # 转为灰度图
        mask_img = self.transforms_mask(mask_img)
        
        # 确保掩码是二值化的
        mask_img = (mask_img > 0.5).float()  # 掩码值为 0 或 1
    
        return img, mask_img


    def __len__(self):
        return len(self.original_files)

In [5]:
class Opt:
    def __init__(self):
        self.n_epochs = 200
        self.batch_size = 8
        self.dataset_name = "img_align_celeba"
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_cpu = 4
        self.latent_dim = 100
        self.img_size = 128
        self.mask_size = 64
        self.channels = 3
        self.sample_interval = 500

    def __repr__(self):
        return (f"Opt(n_epochs={self.n_epochs}, batch_size={self.batch_size}, "
                f"dataset_name='{self.dataset_name}', lr={self.lr}, b1={self.b1}, "
                f"b2={self.b2}, n_cpu={self.n_cpu}, latent_dim={self.latent_dim}, "
                f"img_size={self.img_size}, mask_size={self.mask_size}, "
                f"channels={self.channels}, sample_interval={self.sample_interval})")

In [6]:
import os
from torchvision.utils import save_image
import json
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from sewar.full_ref import vifp  # 第三方库 sewar 提供 VIF 和 FSIM 的实现
import numpy as np

from skimage.metrics import structural_similarity as ssim

def calculate_metrics(original, generated):
    """
    计算图像质量评估指标
    :param original: 原始图像 (H, W, C) 格式
    :param generated: 生成图像 (H, W, C) 格式
    :return: 指标字典
    """
    
   
    
    # # 确保图像被标准化到 [0, 1]
    original = np.array(original).astype(np.float32) / 255.0
    generated = np.array(generated).astype(np.float32) / 255.0    



    # 如果图像只有单通道，增加通道维度
    if original.ndim == 2:
        original = np.expand_dims(original, axis=-1)
    if generated.ndim == 2:
        generated = np.expand_dims(generated, axis=-1)
    
    # PSNR
    psnr_value = psnr(original, generated, data_range=1)

     # SSIM 
    ssim_value, _ = ssim(
    original, 
    generated, 
    channel_axis=-1,  # 指定通道轴为最后一维
    data_range=1.0, 
    full=True, 
    win_size=7
)

    
    # VIF
    vif_value = vifp(original, generated)

    return {
        "PSNR": psnr_value,
        "SSIM": ssim_value,
        "VIF": vif_value,
    }
    
# 确保输出目录存在
def adjust_image_size(image, target_size):
    """
    调整图像大小以匹配目标尺寸。
    :param image: 输入的PIL图像对象
    :param target_size: 目标尺寸 (宽度, 高度)
    :return: 调整大小后的PIL图像对象
    """
    return image.resize(target_size, Image.ANTIALIAS)
    
opt = Opt()

cuda = True if torch.cuda.is_available() else False

def test_model(category, generator_path, output_dir = "./output_images_gan"):
    # 加载训练好的生成器模型
    output_dir += "/" + category
    os.makedirs(output_dir, exist_ok=True)
    generator = Generator(channels=opt.channels)
    generator.load_state_dict(torch.load(generator_path))
    generator.eval()
    
    if cuda:
        generator.cuda()

    # 准备测试数据集加载器
    transforms_ = [
        transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    test_dataloader = DataLoader(
        ImageDataset("./dataset_test", category, transforms_=transforms_, mode="val"),
        batch_size=12,
        shuffle=False,  # 测试时无需打乱数据
        num_workers=1,
    )

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    metrics_sum = {"PSNR": 0.0, "SSIM": 0.0, "VIF": 0.0}
    total_images = 0
    

    for i, (imgs, masks) in enumerate(test_dataloader):
        imgs = Variable(imgs.type(Tensor))
        masks = Variable(masks.type(Tensor))

        # 创建掩码图像（生成器输入）
        masked_imgs = imgs * (1 - masks) + masks

        # 使用生成器预测修复后的图像
        gen_imgs = generator(masked_imgs, masks)

        # 保存生成的图像
        for j in range(gen_imgs.shape[0]):
            img_sample = gen_imgs[j].data
            save_path = f"{output_dir}/output_{category}_{i * len(gen_imgs) + j}.png"
            save_image(img_sample, save_path)

            # 加载保存的生成图像和真实图像
            generated_image = Image.open(save_path).convert("RGB")
            original_image = (imgs[j].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5) * 255
            original_image = Image.fromarray(original_image.astype(np.uint8))

            # 如果需要，调整输出图像大小以匹配原始图像
            if generated_image.size != original_image.size:
                original_image = adjust_image_size(original_image, generated_image.size)

            # 计算指标
            metrics = calculate_metrics(original_image, generated_image)
            for key in metrics:
                metrics_sum[key] += metrics[key]
            total_images += 1
            
     # 计算平均指标
    metrics_avg = {key: metrics_sum[key] / total_images for key in metrics_sum}

    # 保存指标到文件
    metrics_file = os.path.join(output_dir, f"{category}_average_metrics_gan.json")
    with open(metrics_file, "w") as f:
        json.dump(metrics_avg, f, indent=4)

    print(f"Metrics for {category}: {metrics_avg}")
    print(f"Test completed and images saved to {output_dir}")

    print(f"Test completed and images saved to {output_dir}")

# 假设我们只对 'face' 类型进行测试
image_types = ["face", "scenario", "street_scene_pairs", "texture"]
for image_type in image_types:
    #train_model(image_type=image_type, num_epochs=10, batch_size=16, learning_rate=1e-5)
    generator_path = f"{image_type}_generator.pth"
    test_model(image_type, generator_path)

  generator.load_state_dict(torch.load(generator_path))


Metrics for face: {'PSNR': 10.96069994132162, 'SSIM': 0.3195035795960575, 'VIF': 0.8215660663227156}
Test completed and images saved to ./output_images_gan/face
Test completed and images saved to ./output_images_gan/face
Metrics for scenario: {'PSNR': 10.267389511204204, 'SSIM': 0.2773144560400397, 'VIF': 0.7742339353464112}
Test completed and images saved to ./output_images_gan/scenario
Test completed and images saved to ./output_images_gan/scenario
Metrics for street_scene_pairs: {'PSNR': 11.085944655448026, 'SSIM': 0.2836719263705891, 'VIF': 0.7692075080523576}
Test completed and images saved to ./output_images_gan/street_scene_pairs
Test completed and images saved to ./output_images_gan/street_scene_pairs
Metrics for texture: {'PSNR': 9.46899983580659, 'SSIM': 0.26791315180948005, 'VIF': 0.840434243716509}
Test completed and images saved to ./output_images_gan/texture
Test completed and images saved to ./output_images_gan/texture
