In [1]:
import os
from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from diffusers import StableDiffusionInpaintPipeline
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Generator
import gc
# 加载预训练模型
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "./stable-diffusion-2-inpainting",
    torch_dtype=torch.float16
)
pipe.to("cuda")

2025-01-10 22:57:31.092088: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1736521051.224286    4392 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736521051.260204    4392 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-10 22:57:31.594156: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

StableDiffusionInpaintPipeline {
  "_class_name": "StableDiffusionInpaintPipeline",
  "_diffusers_version": "0.31.0",
  "_name_or_path": "./stable-diffusion-2-inpainting",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [2]:
class InpaintingDataset(Dataset):
    def __init__(self, original_image_dir, mask_dir, transform_img=None, transform_mask=None):
        self.original_image_dir = original_image_dir
        self.mask_dir = mask_dir
        self.transform_img = transform_img
        self.transform_mask = transform_mask

        # 获取原始图像和掩码图像的路径
        self.image_paths = []
        for subdir in os.listdir(original_image_dir):
            image_subdir = os.path.join(original_image_dir, subdir)
            mask_subdir = os.path.join(mask_dir, subdir)          
            self.image_paths.append((image_subdir, mask_subdir))
        print(f"Loaded {len(self.image_paths)} image-mask pairs")


    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path, mask_path = self.image_paths[idx]

        # 读取原始图像和掩码图像
        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("1")  # mask是单通道图像

        # 应用数据增强和预处理
        if self.transform_img:
            img = self.transform_img(img)
        if self.transform_mask:
            mask = self.transform_mask(mask)
        
        return img, mask

transform_img = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])

# 为掩码图像定义数据增强（不进行归一化）
transform_mask = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [3]:
import matplotlib.pyplot as plt
from PIL import Image

from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
from sewar.full_ref import vifp  # 第三方库 sewar 提供 VIF 和 FSIM 的实现
import numpy as np
import json
from skimage.metrics import structural_similarity as ssim

def calculate_metrics(original, generated):
    """
    计算图像质量评估指标
    :param original: 原始图像 (H, W, C) 格式
    :param generated: 生成图像 (H, W, C) 格式
    :return: 指标字典
    """
    
   
    
    # # 确保图像被标准化到 [0, 1]
    original = np.array(original).astype(np.float32) / 255.0
    generated = np.array(generated).astype(np.float32) / 255.0    



    # 如果图像只有单通道，增加通道维度
    if original.ndim == 2:
        original = np.expand_dims(original, axis=-1)
    if generated.ndim == 2:
        generated = np.expand_dims(generated, axis=-1)
    
    # PSNR
    psnr_value = psnr(original, generated, data_range=1)

     # SSIM 
    ssim_value, _ = ssim(
    original, 
    generated, 
    channel_axis=-1,  # 指定通道轴为最后一维
    data_range=1.0, 
    full=True, 
    win_size=7
)

    
    # VIF
    vif_value = vifp(original, generated)

    return {
        "PSNR": psnr_value,
        "SSIM": ssim_value,
        "VIF": vif_value,
    }



generator = Generator()
generator.manual_seed(42)

def adjust_image_size(image, target_size):
    """
    调整图像大小以匹配目标尺寸。
    :param image: 输入的PIL图像对象
    :param target_size: 目标尺寸 (宽度, 高度)
    :return: 调整大小后的PIL图像对象
    """
    return image.resize(target_size, Image.ANTIALIAS)
    
def tensor_to_image(tensor):
    """
    将PyTorch张量转换为PIL图像
    :param tensor: 输入的PyTorch张量 (C, H, W)
    :return: PIL图像对象
    """
    tensor = tensor.cpu().clone()  # 克隆张量以避免修改原张量
    tensor = tensor.squeeze(0)  # 移除batch维度
    # unnormalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225])
    # tensor = unnormalize(tensor)  # 反归一化（如果之前进行了归一化）
    # tensor = torch.clamp(tensor, 0, 1)  # 确保像素值在[0, 1]之间
    image = transforms.ToPILImage()(tensor)
    return image


def save_image(image, output_dir, filename):
    """
    保存图像到指定目录
    :param image: PIL图像对象
    :param output_dir: 输出目录路径
    :param filename: 文件名
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    image.save(os.path.join(output_dir, filename))

               

def test_model(image_type, model_path, output_dir="./output_images_sd"):
    """
    测试针对特定图像类型的模型并计算图像质量评估指标平均值
    """
    output_dir += "/" + image_type
    # 加载模型权重
    pipe.unet.load_state_dict(torch.load(model_path))
    pipe.unet.eval()  # 设置为评估模式

    original_image_dir = f"./dataset_test/original_image/{image_type}"
    mask_dir = f"./dataset_test/mask/{image_type}"

    test_dataset = InpaintingDataset(
        original_image_dir=original_image_dir,
        mask_dir=mask_dir,
        transform_img=transform_img,
        transform_mask=transform_mask
    )
    test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    # 初始化累计值
    metrics_sum = {"PSNR": 0.0, "SSIM": 0.0, "VIF": 0.0}
    total_images = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_dataloader, desc=f"Testing {image_type}")):
            images, masks = batch
            images, masks = images.cuda(), masks.cuda()

            prompt = [" "] * images.size(0)

            inputs = pipe.tokenizer(
                prompt,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                max_length=pipe.tokenizer.model_max_length
            )

            for k, v in inputs.items():
                inputs[k] = v.cuda()

            prompt_embeds = pipe.text_encoder(
                inputs.input_ids, attention_mask=inputs.attention_mask
            ).last_hidden_state

            outputs = pipe(
                image=images,
                mask_image=masks,
                prompt_embeds=prompt_embeds,
                generator=generator,
                num_inference_steps=50,
                guidance_scale=7.5
            ).images

            if isinstance(outputs, (list, tuple)):
                outputs = torch.stack([transforms.ToTensor()(img) for img in outputs]).cuda()

            for i in range(outputs.size(0)):
                
                output_image = tensor_to_image(outputs[i])
                filename = f"generated_image_{batch_idx * len(images) + i + 1}.png"
                save_image(output_image, output_dir, filename)
                original_image = tensor_to_image(images[i])

                
                # 如果需要，调整输出图像大小以匹配原始图像
                if output_image.size != original_image.size:
                    original_image = adjust_image_size(original_image, output_image.size)
                    

                # 计算指标
                metrics = calculate_metrics(original_image, output_image)
                for key in metrics_sum:
                    metrics_sum[key] += metrics[key]
                total_images += 1
                

            del images, masks, outputs, prompt_embeds
            torch.cuda.empty_cache()
            gc.collect()
    
    # 计算平均值
    metrics_avg = {key: metrics_sum[key] / total_images for key in metrics_sum}

    # 保存平均值
    metrics_file = os.path.join(output_dir, f"{image_type}_average_metrics_sd.json")
    with open(metrics_file, "w") as f:
        json.dump(metrics_avg, f, indent=4)

    # 打印结果
    print(f"Average metrics for {image_type}:")
    print(metrics_avg)


if __name__ == "__main__":
    # 示例：测试 'face' 类型的模型
    image_types = ["face", "scenario", "street_scene_pairs", "texture"]
    for image_type in image_types:
        model_path = f'unet_model_{image_type}.pth'
        test_model(image_type=image_type, model_path=model_path)

  pipe.unet.load_state_dict(torch.load(model_path))


Loaded 200 image-mask pairs




  0%|          | 0/50 [00:00<?, ?it/s]

  return image.resize(target_size, Image.ANTIALIAS)
Testing face:   8%|██▏                          | 1/13 [01:42<20:25, 102.16s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  15%|████▍                        | 2/13 [03:24<18:45, 102.33s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  23%|██████▋                      | 3/13 [05:06<17:02, 102.22s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  31%|████████▉                    | 4/13 [06:48<15:19, 102.16s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  38%|███████████▏                 | 5/13 [08:31<13:37, 102.19s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  46%|█████████████▍               | 6/13 [10:14<11:58, 102.68s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  54%|███████████████▌             | 7/13 [12:03<10:28, 104.69s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  62%|█████████████████▊           | 8/13 [13:50<08:47, 105.46s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  69%|████████████████████         | 9/13 [15:34<06:59, 104.97s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  77%|█████████████████████▌      | 10/13 [17:16<05:12, 104.15s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  85%|███████████████████████▋    | 11/13 [19:00<03:28, 104.08s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face:  92%|█████████████████████████▊  | 12/13 [20:43<01:43, 103.54s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing face: 100%|█████████████████████████████| 13/13 [21:34<00:00, 99.55s/it]


Average metrics for face:
{'PSNR': 7.264619312404474, 'SSIM': 0.24801716435700655, 'VIF': 0.020357388166735615}
Loaded 200 image-mask pairs


Testing scenario:   0%|                                  | 0/13 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:   8%|█▉                       | 1/13 [01:42<20:27, 102.32s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  15%|███▊                     | 2/13 [03:24<18:45, 102.31s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  23%|█████▊                   | 3/13 [05:06<17:02, 102.30s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  31%|███████▋                 | 4/13 [06:49<15:20, 102.30s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  38%|█████████▌               | 5/13 [08:31<13:38, 102.29s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  46%|███████████▌             | 6/13 [10:13<11:55, 102.25s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  54%|█████████████▍           | 7/13 [11:55<10:13, 102.28s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  62%|███████████████▍         | 8/13 [13:40<08:34, 102.90s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  69%|█████████████████▎       | 9/13 [15:22<06:50, 102.73s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  77%|██████████████████▍     | 10/13 [17:04<05:07, 102.55s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  85%|████████████████████▎   | 11/13 [18:46<03:24, 102.43s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario:  92%|██████████████████████▏ | 12/13 [20:28<01:42, 102.33s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing scenario: 100%|█████████████████████████| 13/13 [21:19<00:00, 98.45s/it]


Average metrics for scenario:
{'PSNR': 7.231791672698777, 'SSIM': 0.17244896695949138, 'VIF': 0.026071536946030118}
Loaded 200 image-mask pairs


Testing street_scene_pairs:   0%|                        | 0/13 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:   8%|█▏             | 1/13 [01:42<20:29, 102.46s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  15%|██▎            | 2/13 [03:24<18:45, 102.35s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  23%|███▍           | 3/13 [05:06<17:02, 102.30s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  31%|████▌          | 4/13 [06:49<15:20, 102.25s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  38%|█████▊         | 5/13 [08:31<13:38, 102.27s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  46%|██████▉        | 6/13 [10:13<11:55, 102.28s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  54%|████████       | 7/13 [11:55<10:13, 102.26s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  62%|█████████▏     | 8/13 [13:38<08:31, 102.26s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  69%|██████████▍    | 9/13 [15:20<06:49, 102.30s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  77%|██████████▊   | 10/13 [17:02<05:06, 102.27s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  85%|███████████▊  | 11/13 [18:45<03:24, 102.27s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs:  92%|████████████▉ | 12/13 [20:27<01:42, 102.24s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing street_scene_pairs: 100%|███████████████| 13/13 [21:18<00:00, 98.32s/it]


Average metrics for street_scene_pairs:
{'PSNR': 7.27517793389928, 'SSIM': 0.1873990525305271, 'VIF': 0.024900372341543617}
Loaded 200 image-mask pairs


Testing texture:   0%|                                   | 0/13 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:   8%|██                        | 1/13 [01:42<20:24, 102.05s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  15%|████                      | 2/13 [03:24<18:42, 102.02s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  23%|██████                    | 3/13 [05:06<17:00, 102.04s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  31%|████████                  | 4/13 [06:48<15:18, 102.03s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  38%|██████████                | 5/13 [08:30<13:36, 102.01s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  46%|████████████              | 6/13 [10:12<11:53, 101.99s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  54%|██████████████            | 7/13 [11:54<10:12, 102.02s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  62%|████████████████          | 8/13 [13:36<08:30, 102.02s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  69%|██████████████████        | 9/13 [15:18<06:48, 102.02s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  77%|███████████████████▏     | 10/13 [17:00<05:05, 102.00s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  85%|█████████████████████▏   | 11/13 [18:42<03:23, 101.96s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture:  92%|███████████████████████  | 12/13 [20:23<01:41, 101.93s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

Testing texture: 100%|██████████████████████████| 13/13 [21:14<00:00, 98.05s/it]

Average metrics for texture:
{'PSNR': 8.125930449558334, 'SSIM': 0.09555807177952375, 'VIF': 0.0769037612807758}



