In [2]:
import torch.nn as nn

class SRCNN(nn.Module):
    """Super-resolution CNN: patch extraction -> non-linear mapping -> reconstruction.

    Every convolution uses stride 1 with "same" padding (9/4, 3/1, 5/2), so the
    output has the same spatial size as the input.
    """

    def __init__(self, num_channels=3, base_channels=64, num_blocks=3):
        """
        Args:
            num_channels (int): Number of input/output image channels (RGB=3, Grayscale=1).
            base_channels (int): Number of feature maps used inside the network.
            num_blocks (int): Number of intermediate (3x3 conv, ReLU) blocks.
        """
        super().__init__()

        # Patch extraction and representation: wide 9x9 receptive field.
        self.first_layer = nn.Sequential(
            nn.Conv2d(num_channels, base_channels, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
        )

        # Non-linear mapping: `num_blocks` repetitions of (3x3 conv, ReLU),
        # flattened into one Sequential (conv at even indices, ReLU at odd).
        self.middle_layers = nn.Sequential(*(
            layer
            for _ in range(num_blocks)
            for layer in (
                nn.Conv2d(base_channels, base_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )
        ))

        # Reconstruction: project features back to image channels, no activation.
        self.last_layer = nn.Conv2d(base_channels, num_channels, kernel_size=5, padding=2)

        # Re-initialise all convolutions (overrides PyTorch's default init).
        self._initialize_weights()

    def forward(self, x):
        """Map an image batch to a reconstructed batch of the same spatial size."""
        features = self.middle_layers(self.first_layer(x))
        return self.last_layer(features)

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU gain) weights and zero biases for every Conv2d.

        NOTE(review): the ReLU gain is also applied to `last_layer`, which is not
        followed by a ReLU.
        """
        convs = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.zeros_(conv.bias)


In [None]:
# Data pipeline: paired LR/HR image dataset and DataLoader factory (the SRCNN model is defined in the previous cell).

from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms

class SRDataset(Dataset):
    """Paired low-/high-resolution image dataset for super-resolution training.

    Samples are matched by file name: the image ``name`` in ``lr_dir`` is paired
    with the file of the same name in ``hr_dir``. Each item is returned as a
    ``(lr_tensor, hr_tensor)`` tuple produced by ``transforms.ToTensor``.
    """

    def __init__(self, lr_dir, hr_dir, mode=None):
        """
        Args:
            lr_dir (str): Directory containing the low-resolution images.
            hr_dir (str): Directory containing the high-resolution images; it must
                hold a file with the same name for every image in ``lr_dir``.
            mode (str, optional): PIL mode (e.g. ``"RGB"`` or ``"L"``) every image
                is converted to before tensor conversion. ``None`` (default) keeps
                each file's native mode, as before.
        """
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.mode = mode
        # Sorted so indices are stable across runs/platforms (os.listdir order is arbitrary).
        self.image_files = self._list_image_files(lr_dir)

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    @staticmethod
    def _list_image_files(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``.

        Skips sub-directories (e.g. Jupyter's ``.ipynb_checkpoints``) and hidden
        files (e.g. ``.DS_Store``) that would otherwise be handed to ``Image.open``.
        """
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        lr_tensor = self._load_tensor(os.path.join(self.lr_dir, img_name))
        hr_tensor = self._load_tensor(os.path.join(self.hr_dir, img_name))
        return lr_tensor, hr_tensor

    def _load_tensor(self, path):
        """Open the image at ``path``, optionally convert its mode, and return it as a tensor.

        The ``with`` block guarantees the file handle is closed, even if conversion
        fails, instead of relying on Pillow's lazy loading / garbage collection.
        """
        with Image.open(path) as src:
            img = src if self.mode is None else src.convert(self.mode)
            return self.transform(img)

# DataLoader 설정
def get_dataloader(lr_dir, hr_dir, batch_size=32, num_workers=4):
    dataset = SRDataset(lr_dir, hr_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    return dataloader

# Usage example: folder paths are resolved relative to the notebook's working directory.
lr_folder, hr_folder = "lr", "hr"
train_dataloader = get_dataloader(lr_dir=lr_folder, hr_dir=hr_folder)


In [None]:

# Training-loop skeleton: each iteration yields one batch of LR/HR pairs.
# NOTE(review): even with an empty body this loop reads and decodes every image
# pair once, which is costly under "Restart & Run All"; fill it in or skip the cell.
for lr_images, hr_images in train_dataloader:
    # Both tensors are laid out as [batch_size, channels, height, width].
    # SRCNN keeps spatial size, so for a pixel-wise loss the LR images presumably
    # need to be pre-upsampled to the HR size — confirm against the data.
    # TODO: implement the training step (forward, loss, backward, optimizer step).
    ...