# 간단한 이미지 복원 모델 구현 및 학습

## 1. 데이터 준비
1. CIFAR-10 데이터셋 다운로드 및 로드
2. 저해상도 이미지 및 노이즈 추가
3. 데이터로더 구성

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# 데이터 변환 설정
transform = transforms.Compose([
    transforms.ToTensor()
])

# CIFAR-10 데이터셋 로드
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# 저해상도 이미지 생성 함수
def create_low_resolution(img, scale_factor):
    low_res = transforms.functional.resize(img, [img.size(1)//scale_factor, img.size(2)//scale_factor])
    low_res = transforms.functional.resize(low_res, [img.size(1), img.size(2)])
    return low_res

# 노이즈 추가 함수
def add_noise(img, noise_factor=0.5):
    noise = torch.randn(img.size()) * noise_factor
    noisy_img = img + noise
    noisy_img = torch.clamp(noisy_img, 0., 1.)
    return noisy_img

In [4]:
# 데이터셋 클래스 정의
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None, degradation='downscale', scale_factor=2):
        self.dataset = dataset
        self.transform = transform
        self.degradation = degradation
        self.scale_factor = scale_factor
        
    def __getitem__(self, index):
        img, _ = self.dataset[index]
        if self.degradation == 'downscale':
            degraded_img = create_low_resolution(img, self.scale_factor)
        elif self.degradation == 'noise':
            degraded_img = add_noise(img)
        else:
            raise ValueError('Invalid degradation type')
        
        if self.transform:
            degraded_img = self.transform(degraded_img)
            img = self.transform(img)
            
        return degraded_img, img
    
    def __len__(self):
        return len(self.dataset)

In [5]:
# 학습 및 테스트 데이터셋 생성 (다운스케일링 예시)
train_data = ImageDataset(train_dataset, degradation='downscale', scale_factor=2)
test_data = ImageDataset(test_dataset, degradation='downscale', scale_factor=2)

# 데이터로더 생성
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

## 2. 모델 설계
- 간단한 CNN 모델 설계

In [6]:
import torch.nn as nn

In [7]:
class SimpleImageRestorationModel(nn.Module):
    def __init__(self):
        super(SimpleImageRestorationModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x

## 3. 모델 학습
1. 손실 함수 및 옵티마이저 설정
2. 학습 루프 구현

In [8]:
model = SimpleImageRestorationModel()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [9]:
# 학습 루프
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for degraded_imgs, original_imgs in train_loader:
        outputs = model(degraded_imgs)
        loss = criterion(outputs, original_imgs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')

Epoch [1/10], Loss: 0.0062
Epoch [2/10], Loss: 0.0027
Epoch [3/10], Loss: 0.0025
Epoch [4/10], Loss: 0.0023
Epoch [5/10], Loss: 0.0023
Epoch [6/10], Loss: 0.0022
Epoch [7/10], Loss: 0.0021
Epoch [8/10], Loss: 0.0021
Epoch [9/10], Loss: 0.0021
Epoch [10/10], Loss: 0.0020


## 4. 모델 평가
1. 평가 지표 계산
2. 테스트 데이터셋을 사용하여 모델 성능 평가

In [10]:
import math

In [13]:
def calculate_psnr(img1, img2):
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    psnr = 20 * torch.log10(torch.tensor(1.0 / math.sqrt(mse)))
    return psnr

In [14]:
model.eval()
total_psnr = 0
with torch.no_grad():
    for degraded_imgs, original_imgs in test_loader:
        outputs = model(degraded_imgs)
        psnr = calculate_psnr(outputs, original_imgs)
        total_psnr += psnr
avg_psnr = total_psnr / len(test_loader)
print(f'Average PSNR: {avg_psnr:.2f} dB')

Average PSNR: 26.95 dB


## 5. 결과 시각화
- 복원된 이미지와 원본 이미지 비교

In [15]:
import matplotlib.pyplot as plt

In [17]:
# 예시 이미지 시각화
dataiter = iter(test_loader)
degraded_imgs, original_imgs = next(dataiter)

# 모델을 사용하여 복원
with torch.no_grad():
    restored_imgs = model(degraded_imgs)

In [18]:
# 첫 번째 이미지 선택
idx = 0
degraded_img = degraded_imgs[idx].permute(1, 2, 0).numpy()
restored_img = restored_imgs[idx].permute(1, 2, 0).numpy()
original_img = original_imgs[idx].permute(1, 2, 0).numpy()

In [19]:
# 이미지 시각화
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.title('Degraded Image')
plt.imshow(degraded_img)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title('Restored Image')
plt.imshow(restored_img)
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('Original Image')
plt.imshow(original_img)
plt.axis('off')

plt.show()

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.07597301..1.0022018].


: 