# 프라이버시 안전한 데이터셋을 활용한 고품질 이미지 복원

## 1. 데이터 준비
1. 고품질 이미지 수집
2. 열화 이미지 생성
3. 데이터셋 구성

In [1]:
import os
from PIL import Image, ImageFilter
import random
import numpy as np
import torch
from torchvision import transforms

In [2]:
# Dataset locations: high-quality sources are read from `hq_image_dir`,
# degraded copies are written to `lq_image_dir`.
hq_image_dir = './data/archive/train/'
lq_image_dir = './data/unsplash_images/lq/train/'

# Create the output directory (and any missing parents); a no-op if it exists.
os.makedirs(name=lq_image_dir, exist_ok=True)

In [3]:
# Degradation function
def degrade_image(image):
    """Return a randomly degraded version of an RGB PIL image.

    Four degradations are applied independently, each with probability 0.5,
    in this order:

    1. Gaussian blur with a radius drawn from [1, 3].
    2. Additive Gaussian noise (mean 0, standard deviation 25 on the
       0-255 scale), clipped back to the valid pixel range.
    3. JPEG re-compression with a quality drawn from [10, 50].
    4. Bicubic down-sampling by a factor drawn from [0.5, 0.8] followed by
       bicubic up-sampling back to the size the image had before this step.

    Args:
        image: PIL image in RGB mode. The noise step adds an (H, W, 3)
            array to the pixel data, so other modes are not supported.

    Returns:
        The degraded PIL image; it has the same size as the input.
    """
    import io  # local import: only needed for the in-memory JPEG round trip

    if random.random() < 0.5:
        # Gaussian blur
        radius = random.uniform(1, 3)
        image = image.filter(ImageFilter.GaussianBlur(radius=radius))
    if random.random() < 0.5:
        # Additive noise. The noise is signed, so it has to be added in
        # floating point and clipped; casting it to uint8 directly (and
        # alpha-blending it with the image) would wrap negative samples and
        # wash the picture out instead of adding zero-mean noise.
        width, height = image.size
        noise = np.random.normal(0, 25, (height, width, 3))
        noisy = np.asarray(image, dtype=np.float32) + noise
        image = Image.fromarray(np.clip(noisy, 0, 255).astype('uint8'))
    if random.random() < 0.5:
        # JPEG compression, done in memory so that no shared 'temp.jpg' is
        # written to (and raced on in) the working directory.
        quality = random.randint(10, 50)
        with io.BytesIO() as buffer:
            image.save(buffer, 'JPEG', quality=quality)
            buffer.seek(0)
            with Image.open(buffer) as compressed:
                # copy() forces the pixel data to load before the buffer closes
                image = compressed.copy()
    if random.random() < 0.5:
        # Resolution loss: shrink, then enlarge back to the pre-shrink size.
        # The target size must be remembered *before* the first resize,
        # otherwise the second resize is a no-op and the image stays small.
        original_size = image.size
        scale_factor = random.uniform(0.5, 0.8)
        new_size = (
            max(1, int(original_size[0] * scale_factor)),
            max(1, int(original_size[1] * scale_factor)),
        )
        image = image.resize(new_size, Image.BICUBIC)
        image = image.resize(original_size, Image.BICUBIC)
    return image

In [19]:
# Generate one degraded (LQ) image per HQ image, saved under the same filename
# so that HQ/LQ pairs can later be matched by name.
for filename in os.listdir(hq_image_dir):
    hq_image_path = os.path.join(hq_image_dir, filename)
    if not os.path.isfile(hq_image_path):
        # Sub-directories cannot be opened as images; skip them.
        continue
    lq_image_path = os.path.join(lq_image_dir, filename)
    # The context manager closes the source file as soon as the RGB copy
    # has been made, instead of leaving the handle open.
    with Image.open(hq_image_path) as source_image:
        image = source_image.convert('RGB')
    degraded_image = degrade_image(image)
    degraded_image.save(lq_image_path)

In [4]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [5]:
class ImageDegradationDataset(Dataset):
    """Paired dataset of low-quality (LQ) and high-quality (HQ) images.

    Pairs are matched by filename: only names present in both directories
    are used, so a stray or missing file in one directory cannot shift the
    pairing. If the two directories share no filename at all, the sorted
    listings are paired by position and truncated to the shorter one.

    Args:
        hq_dir: Directory containing the high-quality images.
        lq_dir: Directory containing the low-quality images.
        transform: Optional callable applied to the HQ image and to the LQ
            image separately. Because it is called twice, a transform with
            random behaviour is not synchronised between the two images.

    Each item is a ``(lq_image, hq_image)`` tuple.
    """

    def __init__(self, hq_dir, lq_dir, transform=None):
        self.hq_dir = hq_dir
        self.lq_dir = lq_dir
        self.transform = transform

        hq_names = sorted(os.listdir(hq_dir))
        lq_names = sorted(os.listdir(lq_dir))
        shared_names = set(hq_names) & set(lq_names)
        if shared_names:
            # Name-based pairing: keep only files that exist on both sides.
            hq_names = [name for name in hq_names if name in shared_names]
            lq_names = list(hq_names)
        else:
            # No common names: fall back to positional pairing, but never
            # index past the end of the shorter listing.
            pair_count = min(len(hq_names), len(lq_names))
            hq_names = hq_names[:pair_count]
            lq_names = lq_names[:pair_count]
        self.hq_images = hq_names
        self.lq_images = lq_names

    def __len__(self):
        return len(self.hq_images)

    def __getitem__(self, idx):
        hq_path = os.path.join(self.hq_dir, self.hq_images[idx])
        lq_path = os.path.join(self.lq_dir, self.lq_images[idx])

        # convert() returns an independent in-memory image, so the file
        # handles can be closed right away.
        with Image.open(hq_path) as hq_file:
            hq_image = hq_file.convert('RGB')
        with Image.open(lq_path) as lq_file:
            lq_image = lq_file.convert('RGB')

        if self.transform is not None:
            hq_image = self.transform(hq_image)
            lq_image = self.transform(lq_image)

        return lq_image, hq_image

In [6]:
# ToTensor turns the PIL images into float tensors. Without a transform the
# dataset yields PIL images, which the DataLoader's default collate function
# cannot batch and which the training loop cannot move to a device.
train_dataset = ImageDegradationDataset(
    hq_dir=hq_image_dir,
    lq_dir=lq_image_dir,
    transform=transforms.ToTensor(),
)
# batch_size=1 is kept from the original setup (presumably because the images
# differ in size and therefore cannot be stacked - TODO confirm).
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2)

In [None]:
# Pull a single (lq, hq) batch from the loader as a quick sanity check.
batch_iterator = iter(train_loader)
next(batch_iterator)

## 2. 모델 설계
1. 네트워크 구조 선택
2. Mixture of Experts(MoE) 구조 도입

In [12]:
import torch.nn as nn

In [13]:
class ResidualBlock(nn.Module):
    """Conv -> ReLU -> Conv block with an identity skip connection.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so
    the output has the same shape as the input.

    Args:
        channels: Number of input and output feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
        )

    def forward(self, x):
        """Return ``x`` plus the residual computed by the conv body."""
        return x + self.body(x)

In [14]:
class EDSR(nn.Module):
    """EDSR-style restoration network with a per-pixel mixture-of-experts stage.

    Pipeline of ``forward``:

    1. ``head`` lifts the input image to ``num_feats`` feature channels.
    2. ``router`` (1x1 conv + softmax over the expert dimension) produces one
       weight map per expert directly from the input image.
    3. Every expert (a ``ResidualBlock``) processes the head features and the
       results are summed, weighted by the router maps.
    4. The mixture is refined by ``num_blocks`` residual blocks, the head
       features are added back as a long skip connection, and ``tail`` maps
       the result back to ``num_channels`` channels.

    Args:
        num_channels: Channels of the input and output image.
        num_feats: Width of the internal feature maps.
        num_blocks: Number of residual blocks in the body.
        num_experts: Number of expert blocks mixed by the router.
    """

    def __init__(self, num_channels=3, num_feats=64, num_blocks=16, num_experts=4):
        super().__init__()
        self.head = nn.Conv2d(num_channels, num_feats, kernel_size=3, padding=1)
        self.body = nn.ModuleList([ResidualBlock(num_feats) for _ in range(num_blocks)])
        self.tail = nn.Conv2d(num_feats, num_channels, kernel_size=3, padding=1)
        self.num_experts = num_experts
        self.router = nn.Sequential(
            nn.Conv2d(num_channels, num_experts, kernel_size=1),
            nn.Softmax(dim=1)
        )
        self.experts = nn.ModuleList([ResidualBlock(num_feats) for _ in range(num_experts)])

    def forward(self, x):
        """Restore ``x`` and return an image with the same shape."""
        feat = self.head(x)
        # Router: per-pixel expert weights, shape (N, num_experts, H, W)
        weights = self.router(x)
        # Weighted sum of the expert outputs. The slice i:i+1 keeps the
        # channel axis so the weight map broadcasts over the feature channels.
        # (The original looked up `self.sxperts`, which does not exist.)
        expert_out = sum(
            expert(feat) * weights[:, i:i + 1, :, :]
            for i, expert in enumerate(self.experts)
        )
        res = expert_out
        for block in self.body:
            res = block(res)
        # Long skip connection around the expert mixture and the body
        res = res + feat
        out = self.tail(res)
        return out
    
# Build the network with its default configuration, spelled out explicitly.
model = EDSR(num_channels=3, num_feats=64, num_blocks=16, num_experts=4)

## 3. 모델 학습
1. 손실 함수 및 옵티마이저 설정
2. 학습 루프 구현

In [15]:
from torchvision import models

In [16]:
# Frozen VGG-19 feature extractor for the perceptual loss: the first 35
# layers of `features`, in eval mode and with gradients disabled for its weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument - confirm against the installed version.
vgg = models.vgg19(pretrained=True).features[:35].eval()
for param in vgg.parameters():
    param.requires_grad = False


def perceptual_loss(output, target):
    """Return the L1 distance between the VGG features of two image batches.

    Args:
        output: Batch of restored images.
        target: Batch of reference images with the same shape as ``output``.

    Returns:
        Scalar tensor holding the mean absolute difference of the features.
    """
    output_vgg = vgg(output)
    target_vgg = vgg(target)
    loss = nn.functional.l1_loss(output_vgg, target_vgg)
    # Return the computed loss (the original returned the `vgg` module itself,
    # which made `0.01 * loss2` in the training loop fail).
    return loss



In [17]:
# Device, loss function, optimizer and learning-rate schedule
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The training loop moves every batch to `device`, so both networks have to
# live there as well; otherwise the forward pass fails with a device mismatch
# whenever CUDA is available. Module.to() moves the parameters in place.
model.to(device)
vgg.to(device)

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Halve the learning rate every 20 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

In [None]:
# Training loop: L1 reconstruction loss plus a 0.01-weighted perceptual term.
# The scheduler is stepped once per epoch, after the epoch's average loss is printed.
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for lq_imgs, hq_imgs in train_loader:
        lq_imgs, hq_imgs = lq_imgs.to(device), hq_imgs.to(device)

        # Clear gradients left over from the previous step before the new pass
        optimizer.zero_grad()
        outputs = model(lq_imgs)
        pixel_loss = criterion(outputs, hq_imgs)
        feature_loss = perceptual_loss(outputs, hq_imgs)
        loss = pixel_loss + 0.01 * feature_loss
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')
    scheduler.step()

## 4. 모델 평가
1. 평가 지표 계산
2. 테스트 데이터셋을 사용하여 모델 성능 평가

In [None]:
from lpips import LPIPS

In [None]:
# LPIPS perceptual metric with the AlexNet backbone, placed on the compute device.
lpips_loss_fn = LPIPS(net='alex')
lpips_loss_fn = lpips_loss_fn.to(device)

In [None]:
# Evaluate the model on the test set with PSNR, SSIM and LPIPS.
# NOTE(review): `test_loader` is not defined in the visible part of this
# notebook; it has to be created before this cell is run.
import pytorch_ssim  # used for SSIM below but not imported in any visible cell

model.eval()
total_psnr = 0
total_ssim = 0
total_lpips = 0
with torch.no_grad():
    for lq_imgs, hq_imgs in test_loader:
        lq_imgs = lq_imgs.to(device)
        # Compare against the ground truth (the original assigned the LQ
        # batch here, so every metric was computed against the input itself).
        hq_imgs = hq_imgs.to(device)
        # The PSNR formula below assumes a [0, 1] pixel range, so clamp the
        # network output to that range before computing any metric.
        outputs = model(lq_imgs).clamp(0.0, 1.0)
        # PSNR (peak signal value 1.0)
        mse = torch.mean((outputs - hq_imgs) ** 2)
        psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
        total_psnr += psnr.item()
        # SSIM
        ssim = pytorch_ssim.ssim(outputs, hq_imgs).item()
        total_ssim += ssim
        # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1] data
        lpips_value = lpips_loss_fn(outputs, hq_imgs, normalize=True).mean().item()
        total_lpips += lpips_value
avg_psnr = total_psnr / len(test_loader)
avg_ssim = total_ssim / len(test_loader)
avg_lpips = total_lpips / len(test_loader)
print(f'Average PSNR: {avg_psnr:.2f} dB')
print(f'Average SSIM: {avg_ssim:.4f}')
print(f'Average LPIPS: {avg_lpips:.4f}')

## 5. 결과 분석 및 개선
1. 결과 시각화
2. 모델 개선 방안 제시

In [None]:
# 결과 시각화
import matplotlib.pyplot as plt

In [None]:
# Run the model on the first test batch to obtain images for visualisation.
model.eval()
lq_imgs, hq_imgs = next(iter(test_loader))
lq_imgs, hq_imgs = lq_imgs.to(device), hq_imgs.to(device)
with torch.no_grad():  # inference only, no gradients needed
    outputs = model(lq_imgs)

In [None]:
# Show the first sample of the batch: input, restoration and ground truth
# side by side.
idx = 0

# Move each tensor to the CPU, reorder its axes with permute(1, 2, 0) and convert to NumPy
lq_img, output_img, hq_img = (
    batch[idx].cpu().permute(1, 2, 0).numpy()
    for batch in (lq_imgs, outputs, hq_imgs)
)

plt.figure(figsize=(15, 5))

panels = (
    ('Low-Quality Image', lq_img),
    ('Restored Image', output_img),
    ('High-Quality Image', hq_img),
)
for position, (title, picture) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    plt.imshow(picture)
    plt.axis('off')

plt.show()