In [17]:
# %pip install skimage

In [6]:
import torch
from torchvision import models, transforms
from torchvision.utils import make_grid
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# from skimage.metrics import structural_similarity as ssim
# from pytorch_fid import fid_score
# import lpips

device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cuda


### 加载模型

In [8]:
from diffusers import DDPMScheduler, UNet2DModel

# 实例化模型
net = UNet2DModel(
    sample_size=128,
    in_channels=1,
    out_channels=1,
    block_out_channels=(32, 64, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
).to(device)

# 加载checkpoints
checkpoint = torch.load('checkpoints_0708/best_model.pth', map_location=device)
# print(checkpoint.keys())

# 切换到eval
net.eval()

UNet2DModel(
  (conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=32, out_features=128, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=128, out_features=128, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock2D(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock2D(
          (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
          (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (nonlinearity): SiLU()
        )
      )
      (downsamplers): ModuleList(
        (0): Downsample2D(
          (conv): Conv2d(32, 32, ker

In [9]:
# 测试加载
dummy_input = torch.randn(36, 1, 128, 128).to(device)
with torch.no_grad():
    output = net(dummy_input, timestep=0).sample
print(output.shape)  # [36, 1, 128, 128]

torch.Size([36, 1, 128, 128])


In [16]:
from torch.cuda.amp import GradScaler, autocast
from diffusers import DDPMScheduler, UNet2DModel
from tqdm.auto import tqdm

# 创建 scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

In [17]:
# 换scheduler
# 替换为 DDIMScheduler 或 DPMSolverMultistepScheduler
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler
noise_scheduler_ddim = DDIMScheduler.from_config(noise_scheduler.config) # 50步
noise_scheduler_dpm = DPMSolverMultistepScheduler.from_config(noise_scheduler.config) #20-30步
noise_scheduler_ddim.set_timesteps(50)
noise_scheduler_dpm.set_timesteps(30)

In [18]:
# Prepare random x to start from, plus desired labels y
num_classes = 36
samples_per_class = 8  # 每类生成8张
x = torch.randn(num_classes * samples_per_class, 1, 128, 128).to(device)
y = torch.tensor([[i] * samples_per_class for i in range(num_classes)]).flatten().to(device)

# Reverse diffusion (sampling loop)
for t in tqdm(noise_scheduler_dpm.timesteps):
    with torch.no_grad():
        residual = net(x, t, y)  # 条件在 class label
        # print(residual.shape)
    x = noise_scheduler_dpm.step(residual, t, x).prev_sample

# Postprocess to [0, 1]
samples = (x.detach().cpu().clip(-1, 1) + 1) / 2

# Show grid of generated images
grid = torchvision.utils.make_grid(samples, nrow=samples_per_class)
plt.figure(figsize=(15, 40))
plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap="gray")
plt.axis("off")
plt.title("Generated Samples (Conditioned on Class Labels)")
plt.show()


  0%|          | 0/30 [00:00<?, ?it/s]

ValueError: class_embedding needs to be initialized in order to use class conditioning

## 1.分类器

In [None]:
# 预处理 & 加载分类器
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
resnet50 = models.resnet50(pretrained=True).eval()

# LPIPS 模型
lpips_fn = lpips.LPIPS(net='alex').eval()

def get_predicted_class(img_pil):
    """用 pretrained 预测图片类别"""
    input_tensor = preprocess(img_pil).unsqueeze(0)
    with torch.no_grad():
        output = resnet50(input_tensor)
    return output.argmax(dim=1).item()

def compute_metrics(original_images, generated_images, target_classes, class_labels):
    """
    original_images, generated_images: list of PIL.Image
    target_classes: list of int (target class index)
    class_labels: dict {class_index: class_name}
    """
    correct = 0
    ssim_scores, psnr_scores, lpips_scores = [], [], []
    
    for orig, gen, target in zip(original_images, generated_images, target_classes):
        # 分类器评估
        pred_class = get_predicted_class(gen)
        if pred_class == target:
            correct += 1
        
        # SSIM
        ssim_val = ssim(np.array(orig.convert('L')), np.array(gen.convert('L')))
        ssim_scores.append(ssim_val)
        
        # PSNR
        mse = np.mean((np.array(orig) - np.array(gen)) ** 2)
        psnr_val = 20 * np.log10(255.0 / np.sqrt(mse)) if mse > 0 else 100
        psnr_scores.append(psnr_val)
        
        # LPIPS
        lpips_val = lpips_fn(preprocess(orig).unsqueeze(0), preprocess(gen).unsqueeze(0))
        lpips_scores.append(lpips_val.item())
        
        # 可视化
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))
        axes[0].imshow(orig)
        axes[0].set_title(f"Original: {class_labels[target]}")
        axes[1].imshow(gen)
        axes[1].set_title(f"Generated (Pred: {class_labels[pred_class]})")
        diff = np.abs(np.array(orig).astype(float) - np.array(gen).astype(float))
        axes[2].imshow(diff.astype(np.uint8))
        axes[2].set_title("Difference Map")
        for ax in axes: ax.axis('off')
        plt.show()
    
    accuracy = correct / len(generated_images)
    print(f"\n Class-conditioned Accuracy: {accuracy:.2%}")
    print(f"Avg SSIM: {np.mean(ssim_scores):.4f}")
    print(f"Avg PSNR: {np.mean(psnr_scores):.2f} dB")
    print(f"Avg LPIPS: {np.mean(lpips_scores):.4f}")
    
    # 可选: FID (需要把图片存到目录)
    # fid_value = fid_score.calculate_fid_given_paths([gen_dir, orig_dir], batch_size=50, device='cuda', dims=2048)
    # print(f"FID: {fid_value:.2f}")

    return {
        "accuracy": accuracy,
        "ssim": np.mean(ssim_scores),
        "psnr": np.mean(psnr_scores),
        "lpips": np.mean(lpips_scores)
    }
