In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from diffusers import UNet2DModel
from timm.models import resnet18
import lpips
from skimage.metrics import structural_similarity as ssim
import numpy as np
from PIL import Image
import os

# 설정
class Config:
    image_size = 256
    train_batch_size = 32
    test_batch_size = 32
    num_epochs = 10
    learning_rate = 1e-4
    train_ratio = 0.8 # 학습 데이터 비율
    data_dir = "DATA" # 데이터 경로

config = Config()

import torch.nn.functional as F

class Autoencoder(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.encoder = nn.Sequential(*list(backbone.children())[:-2])  # ResNet18에서 마지막 두 레이어 제거
        self.channel_reduction = nn.Conv2d(512, 3, kernel_size=1)  # 채널 축소 (512 → 3)
        self.decoder = UNet2DModel(
            sample_size=256,  # 입력 이미지 크기
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=(128, 256, 512),
            down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"),
            up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"),
        )

    def forward(self, x):
        latent = self.encoder(x)
        latent = self.channel_reduction(latent)  # 채널 수 축소 (512 → 3)
        batch_size = latent.shape[0]
        timestep = torch.zeros(batch_size, dtype=torch.long, device=latent.device)  # timestep 설정
        decoded = self.decoder(latent, timestep).sample  # U-Net 디코더
        
        # 출력 크기를 원본 크기로 보정
        decoded = F.interpolate(decoded, size=x.shape[2:], mode="bilinear", align_corners=False)
        return decoded

# 모델 및 장치 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backbone = resnet18(pretrained=True)
model = Autoencoder(backbone).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
loss_fn = nn.MSELoss()

# 손실 함수 및 오류 계산 함수 (이전 코드와 동일)
def ssim(X, Y, data_range):
    C1 = (0.01 * data_range)**2
    C2 = (0.03 * data_range)**2

    mu_x = nn.AvgPool2d(kernel_size=3, stride=1)(X)
    mu_y = nn.AvgPool2d(kernel_size=3, stride=1)(Y)

    sigma_x = nn.AvgPool2d(kernel_size=3, stride=1)(X**2) - mu_x**2
    sigma_y = nn.AvgPool2d(kernel_size=3, stride=1)(Y**2) - mu_y**2
    sigma_xy = nn.AvgPool2d(kernel_size=3, stride=1)(X*Y) - mu_x*mu_y

    SSIM = (2*mu_x*mu_y + C1)*(2*sigma_xy + C2) / (mu_x**2 + mu_y**2 + C1)*(sigma_x + sigma_y + C2)
    return SSIM.mean()

def calculate_reconstruction_error(original_image, reconstructed_image):
    # MSE (Pixel-wise 차이)
    mse = nn.MSELoss()(original_image, reconstructed_image).item()

    # PSNR (Peak Signal-to-Noise Ratio)
    mse_np = mse # tensor to np
    psnr = 10 * np.log10(1 / mse_np) if mse_np > 0 else float('inf')
    
    # SSIM (Structural Similarity)
    original_image_np = original_image.cpu().detach().numpy()[0].transpose(1,2,0)
    reconstructed_image_np = reconstructed_image.cpu().detach().numpy()[0].transpose(1,2,0)
    original_image_np = (original_image_np * 0.5 + 0.5)
    reconstructed_image_np = (reconstructed_image_np * 0.5 + 0.5)

    ssim_value = ssim(original_image_np, reconstructed_image_np, channel_axis=2, data_range = 1)
    
    #LPIPS
    perceptual_loss_fn = lpips.LPIPS(net='vgg').to("cuda") # VGG 기반 Perceptual Loss
    perceptual_loss = perceptual_loss_fn(original_image, reconstructed_image).item()
    
    return mse, psnr, ssim_value, perceptual_loss

# 데이터 로더
transform = transforms.Compose([
    transforms.Resize((config.image_size, config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = datasets.ImageFolder(root=config.data_dir, transform=transform)

# 학습/테스트 데이터 분할
train_size = int(config.train_ratio * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=os.cpu_count())
test_dataloader = DataLoader(test_dataset, batch_size=config.test_batch_size, shuffle=False, num_workers=os.cpu_count())

# 학습 루프
for epoch in range(config.num_epochs):
    model.train()
    for batch in train_dataloader:
        images, _ = batch # _는 클래스 레이블 (사용하지 않음)
        images = images.to(device)
        optimizer.zero_grad()
        reconstructed_images = model(images)
        loss = loss_fn(reconstructed_images, images)
        loss.backward()
        optimizer.step()

    # 검증 (테스트 데이터셋의 일부 사용)
    model.eval()
    test_errors = []
    with torch.no_grad():
        for batch in test_dataloader:
            images, _ = batch
            images = images.to(device)
            reconstructed_images = model(images)
            for i in range(len(images)):
                mse, psnr, ssim_value, perceptual_loss = calculate_reconstruction_error(images[i].unsqueeze(0), reconstructed_images[i].unsqueeze(0))
                test_errors.append([mse, psnr, ssim_value, perceptual_loss])

    # 평균 오류 계산
    test_errors = np.array(test_errors)
    avg_mse, avg_psnr, avg_ssim, avg_lpips = np.mean(test_errors, axis=0)
    print(f"Epoch {epoch+1}, Test MSE: {avg_mse}, PSNR: {avg_psnr}, SSIM: {avg_ssim}, LPIPS: {avg_lpips}")

# 임계값 설정 (테스트 데이터셋 전체 사용)
model.eval()
all_test_errors = []
with torch.no_grad():
    for batch in test_dataloader:
        images, _ = batch
        images = images.to(device)
        reconstructed_images = model(images)
        for i in range(len(images)):
            mse, psnr, ssim_value, perceptual_loss = calculate_reconstruction_error(images[i].unsqueeze(0), reconstructed_images[i].unsqueeze(0))
            all_test_errors.append([mse, psnr, ssim_value, perceptual_loss])

all_test_errors = np.array(all_test_errors)

# 임계값 설정 (예시: MSE 기준)
mse_threshold = np.percentile(all_test_errors[:, 0], 95) # 상위 5%를 임계값으로 설정 (조정 가능)
print(f"MSE Threshold: {mse_threshold}")

# 새로운 이미지에 대한 분류 (추론)
def classify_image(image_path, threshold=mse_threshold):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        reconstructed_image = model(image)
        mse, psnr, ssim_value, perceptual_loss = calculate_reconstruction_error(image, reconstructed_image)
    if mse > threshold:
        return "Unseen Class", mse, psnr, ssim_value, perceptual_loss
    else:
        return "Seen Class", mse, psnr, ssim_value, perceptual_loss

# 새로운 이미지 테스트
test_image_path = "160_master.jpeg" # 테스트할 이미지 경로
classification_result, mse, psnr, ssim_value, perceptual_loss = classify_image(test_image_path)
print(f"Image Classification: {classification_result}, MSE: {mse}, PSNR: {psnr}, SSIM: {ssim_value}, LPIPS: {perceptual_loss}")

TypeError: ssim() got an unexpected keyword argument 'channel_axis'