# Recolor_Restore_image


### 라이브러리 임포트

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import pandas as pd
import torchvision.transforms as transforms
import numpy as np
from torch.optim import Adam

### 데이터셋 클래스
- 손상된 이미지(입력)와 원본 이미지(정답) 쌍을 로드
- CSV 파일의 `input_image_path`, `gt_image_path` 열에서 이미지 경로를 읽어옴 (CSV 파일이 있는 디렉토리 기준 상대 경로)
- 이미지를 흑백(1채널)으로 변환한 뒤 텐서로 바꾸고 정규화 ([0,1] -> [-1,1])

In [None]:
class DamagedImageDataset(Dataset):
    """Paired (damaged, original) grayscale image dataset indexed by a CSV file.

    Each CSV row provides ``input_image_path`` (the damaged image) and
    ``gt_image_path`` (the original image). Paths are resolved relative to the
    directory that contains the CSV. Both images are loaded as single-channel
    grayscale tensors normalized from [0, 1] to [-1, 1].
    """

    def __init__(self, csv_path, num_samples=None):
        """
        Args:
            csv_path: Path to the train.csv index file.
            num_samples: If truthy, keep only the first ``num_samples`` rows
                (useful for quick tests). ``None`` uses the full dataset.
        """
        self.data = pd.read_csv(csv_path)
        if num_samples:
            self.data = self.data.head(num_samples)

        # Image paths in the CSV are relative to the CSV's own directory.
        self.root_dir = os.path.dirname(csv_path)

        self.transform = transforms.Compose([
            # Redundant for images already converted to "L", kept as a safeguard.
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),               # [0, 255] -> [0, 1]
            transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.data)

    def _resolve_path(self, rel_path):
        """Join a CSV path onto ``root_dir``, stripping only a leading './'."""
        # str.replace('./', '') would also rewrite '../' and any interior './'.
        if rel_path.startswith('./'):
            rel_path = rel_path[2:]
        return os.path.join(self.root_dir, rel_path)

    @staticmethod
    def _load_gray(path):
        """Open an image, convert it to 8-bit grayscale ("L"), and close the file."""
        # convert() returns a new, fully loaded image, so closing the source is safe.
        with Image.open(path) as img:
            return img.convert("L")

    def __getitem__(self, idx):
        """Return ``(damaged_tensor, original_tensor)``, each 1 x H x W in [-1, 1]."""
        row = self.data.iloc[idx]
        input_img = self._load_gray(self._resolve_path(row['input_image_path']))
        gt_img = self._load_gray(self._resolve_path(row['gt_image_path']))
        return self.transform(input_img), self.transform(gt_img)

    @staticmethod
    def create_random_mask(size=256, mask_size=128):
        """Create a ``size`` x ``size`` mask of ones with one zeroed square hole.

        The ``mask_size`` x ``mask_size`` hole is placed at a uniformly random
        position that keeps it fully inside the mask (edges included).

        Raises:
            ValueError: if ``mask_size`` is negative or larger than ``size``.
        """
        if not 0 <= mask_size <= size:
            raise ValueError(f'mask_size must be in [0, {size}], got {mask_size}')
        mask = torch.ones((size, size))
        # randint's upper bound is exclusive; +1 allows the hole to touch the far edge.
        x = np.random.randint(0, size - mask_size + 1)
        y = np.random.randint(0, size - mask_size + 1)
        mask[x:x + mask_size, y:y + mask_size] = 0
        return mask


### 모델

In [None]:
class Encoder(nn.Module):
    """Convolutional downsampler forming the first half of the restoration net.

    Five 4x4, stride-2, padding-1 convolutions widen the channels
    1 -> 64 -> 128 -> 256 -> 512 -> 1024 while halving height and width at
    every stage (e.g. 1 x 512 x 512 in, 1024 x 16 x 16 out). Each stage ends
    in LeakyReLU(0.2); every stage except the first applies BatchNorm2d
    before the activation.
    """

    def __init__(self):
        super().__init__()
        widths = (1, 64, 128, 256, 512, 1024)
        layers = []
        for stage, (in_ch, out_ch) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1))
            if stage:  # the first stage has no normalization
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2))
        # Fixed build order keeps the Sequential indices (state_dict keys
        # encoder.0 ... encoder.13) stable.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 1024-channel feature map computed from ``x``."""
        features = self.encoder(x)
        return features

class Decoder(nn.Module):
    """Transposed-convolution upsampler forming the second half of the net.

    Five 4x4, stride-2, padding-1 transposed convolutions narrow the channels
    1024 -> 512 -> 256 -> 128 -> 64 -> 1 while doubling height and width at
    every stage (e.g. 1024 x 16 x 16 in, 1 x 512 x 512 out). Intermediate
    stages use BatchNorm2d + ReLU; the last stage uses Tanh, so outputs lie
    in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (1024, 512, 256, 128, 64, 1)
        layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1))
            is_last = out_ch == widths[-1]
            if is_last:
                layers.append(nn.Tanh())
            else:
                layers.extend([nn.BatchNorm2d(out_ch), nn.ReLU()])
        # Fixed build order keeps the Sequential indices (state_dict keys
        # decoder.0 ... decoder.13) stable.
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 1-channel image in [-1, 1] reconstructed from ``x``."""
        image = self.decoder(x)
        return image

class ContextEncoder(nn.Module):
    """Encoder-decoder restoration network.

    Maps a 1-channel image to a 1-channel image in [-1, 1]. For inputs whose
    height and width are multiples of 32, the output has the input's size:
    the encoder halves the resolution five times and the decoder doubles it
    five times.
    """

    def __init__(self):
        super().__init__()
        # Registration order (encoder, then decoder) fixes the state_dict layout.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` to a feature map and decode it back to an image."""
        return self.decoder(self.encoder(x))


### 이미지 저장

In [None]:
import matplotlib.pyplot as plt
import os

def tensor_to_display(batch):
    """Convert the first image of a [-1, 1] (N, C, H, W) tensor to a [0, 1] array.

    Single-channel images come back as (H, W) so matplotlib applies the
    colormap directly; multi-channel images come back as (H, W, C).
    """
    img = batch[0].detach().cpu().numpy().transpose(1, 2, 0)
    if img.shape[-1] == 1:
        img = img[..., 0]
    return (img * 0.5 + 0.5).clip(0, 1)


def save_images(damaged, output, original, epoch, save_dir='results'):
    """Save a Damaged / Output / Original comparison of the first batch item.

    Args:
        damaged, output, original: Batched image tensors in [-1, 1].
        epoch: Used in the file name ``epoch_{epoch}.png``; repeated calls
            within the same epoch overwrite that file.
        save_dir: Output directory, created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    panels = (('Damaged', damaged), ('Output', output), ('Original', original))
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, (title, batch) in zip(axes, panels):
            # Fixed vmin/vmax: otherwise imshow stretches each grayscale panel
            # to its own min/max, hiding brightness differences between panels.
            ax.imshow(tensor_to_display(batch), cmap='gray', vmin=0, vmax=1)
            ax.set_title(title)
            ax.axis('off')

        save_path = os.path.join(save_dir, f'epoch_{epoch}.png')
        fig.savefig(save_path, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


### 설정 클래스 정의

In [None]:
from dataclasses import dataclass, field
from typing import Any

@dataclass
class ModelConfig:
    """Model architecture hyperparameters.

    NOTE(review): in the code visible here, Encoder/Decoder hard-code these
    same widths instead of reading this config; keep the two in sync.
    """

    input_channels: int = 1      # grayscale input
    output_channels: int = 1     # grayscale output
    initial_features: int = 64   # channels after the first encoder conv
    latent_dim: int = 1024       # channels of the encoder's final feature map
    image_size: int = 512        # expected input height/width in pixels

@dataclass
class TrainConfig:
    """Optimization, data-loading, and output settings for training."""

    # --- optimization / data loading ---
    num_epochs: int = 100
    batch_size: int = 32
    learning_rate: float = 0.0002   # Adam step size
    beta1: float = 0.5              # Adam first-moment decay
    beta2: float = 0.999            # Adam second-moment decay
    num_workers: int = 4            # DataLoader worker processes

    # --- outputs ---
    results_dir: str = 'results'    # where sample images and checkpoints go
    save_frequency: int = 100       # save sample images every N batches
    print_frequency: int = 50

def _default_device():
    """Return the CUDA device when available, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


@dataclass
class Config:
    """Top-level configuration: model settings, training settings, and device.

    Constructing an instance creates ``train.results_dir`` on disk if needed.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    # Resolved per instance at construction time.
    device: torch.device = field(default_factory=_default_device)

    def __post_init__(self):
        # Make sure the output directory exists before anything writes to it.
        results_dir = self.train.results_dir
        os.makedirs(results_dir, exist_ok=True)

# Default module-level configuration (constructing it creates results_dir).
config: Config = Config()

### 학습

In [None]:
from tqdm import tqdm  

def _train_one_epoch(model, train_loader, criterion, optimizer, device, config, epoch):
    """Run one optimization pass over ``train_loader``.

    Returns the per-sample mean training loss, or None if the loader is empty.
    Sample images are saved every ``config.train.save_frequency`` batches.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader),
                        desc=f"Train Epoch {epoch+1}/{config.train.num_epochs}")

    for batch_idx, (damaged, original) in progress_bar:
        damaged = damaged.to(device)
        original = original.to(device)

        optimizer.zero_grad()
        output = model(damaged)
        loss = criterion(output, original)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the mean
        # (assumes criterion returns a batch-mean, e.g. nn.MSELoss default).
        batch_size = damaged.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
        progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        if batch_idx % config.train.save_frequency == 0:
            save_images(damaged, output, original,
                        epoch=epoch,
                        save_dir=config.train.results_dir)

    return total_loss / num_samples if num_samples else None


def compute_average_loss(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns the per-sample mean loss, with each batch's loss weighted by its
    batch size (assumes ``criterion`` returns a batch-mean). Returns None
    when ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for damaged, original in loader:
            damaged = damaged.to(device)
            original = original.to(device)
            batch_size = damaged.size(0)
            total_loss += criterion(model(damaged), original).item() * batch_size
            num_samples += batch_size
    return total_loss / num_samples if num_samples else None


def train(model, train_loader, val_loader, criterion, optimizer, device, config):
    """Train ``model`` for ``config.train.num_epochs`` epochs.

    After each epoch the model is evaluated on ``val_loader``. Whenever the
    monitored loss improves, the weights are saved to
    ``<results_dir>/best_model.pth``. The monitored loss is the validation
    loss, or the training loss when the validation set is empty (e.g. a tiny
    ``num_samples``).

    Returns:
        The best monitored loss seen (``inf`` if no epoch produced one).
    """
    def fmt(value):
        return 'n/a' if value is None else f'{value:.4f}'

    num_epochs = config.train.num_epochs
    os.makedirs(config.train.results_dir, exist_ok=True)
    best_path = os.path.join(config.train.results_dir, 'best_model.pth')
    best_loss = float('inf')

    for epoch in range(num_epochs):
        avg_train_loss = _train_one_epoch(model, train_loader, criterion,
                                          optimizer, device, config, epoch)
        avg_val_loss = compute_average_loss(model, val_loader, criterion, device)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Training Loss: {fmt(avg_train_loss)}')
        print(f'Validation Loss: {fmt(avg_val_loss)}')

        monitored = avg_val_loss if avg_val_loss is not None else avg_train_loss
        if monitored is not None and monitored < best_loss:
            best_loss = monitored
            torch.save(model.state_dict(), best_path)

    return best_loss

def main(csv_path='data/train.csv', num_samples=10, train_ratio=0.8, seed=42):
    """Train a ``ContextEncoder`` on damaged/original image pairs.

    Args:
        csv_path: CSV index read by ``DamagedImageDataset``.
        num_samples: Row cap forwarded to the dataset. Defaults to the small
            debugging subset of 10; pass None to use every row.
        train_ratio: Fraction of samples used for training; the remainder
            forms the validation split.
        seed: Seed for the reproducible train/validation split.

    Raises:
        ValueError: if ``train_ratio`` is outside (0, 1] or the dataset is empty.
    """
    if not 0.0 < train_ratio <= 1.0:
        raise ValueError(f'train_ratio must be in (0, 1], got {train_ratio}')

    config = Config()
    device = config.device
    print(f'Using device: {device}')

    full_dataset = DamagedImageDataset(csv_path, num_samples=num_samples)
    num_total = len(full_dataset)
    if num_total == 0:
        raise ValueError(f'No samples found in {csv_path}')

    # int() floors to 0 for very small datasets; keep at least one training sample.
    train_size = max(1, int(train_ratio * num_total))
    val_size = num_total - train_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

    def make_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=config.train.batch_size,
            shuffle=shuffle,
            num_workers=config.train.num_workers,
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    val_loader = make_loader(val_dataset, shuffle=False)

    model = ContextEncoder().to(device)
    criterion = nn.MSELoss()
    optimizer = Adam(
        model.parameters(),
        lr=config.train.learning_rate,
        betas=(config.train.beta1, config.train.beta2),
    )

    train(model, train_loader, val_loader, criterion, optimizer, device, config)


if __name__ == '__main__':
    main()