In [None]:
import numpy as np
import matplotlib.pyplot as plt
import math
import cv2
from PIL import Image
import glob
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
import torch.optim as optim
from torch.utils.data import DataLoader

# Mount Google Drive at /content/drive; the dataset archive read below lives under that mount.
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# 학습 데이터 압축 해제
import os
import zipfile

# Archive on the mounted Google Drive that is unpacked into /dataset.
zip_path = '/content/drive/MyDrive/Colab Notebooks/2025_1 딥러닝/new_COCO2.zip'

# The context manager closes the archive even when extraction raises;
# the manual open/close pair leaked the handle on failure.
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall('/dataset')

In [None]:
# PSNR helper. Everything is computed with torch ops, so the result stays on
# the device of the inputs (no GPU -> CPU -> numpy round trip).
def compute_psnr(img1, img2, max_val=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Args:
        img1, img2: tensors of identical (or broadcastable) shape.
        max_val: peak value of the data range. Defaults to 1.0 because the
            tensors passed in by this notebook come from ``ToTensor`` and
            ``clamp(0, 1)``; using 255 on such data overstates PSNR by
            20*log10(255) ~= 48 dB. Pass ``max_val=255.0`` for 0-255 data.

    Returns:
        0-dim tensor holding ``20 * log10(max_val / sqrt(mse))``. The value is
        ``inf`` when the two inputs are identical (mse == 0).
    """
    mse = torch.mean((img1 - img2) ** 2)  # mean squared error over all elements
    return 20 * torch.log10(max_val / torch.sqrt(mse))

In [None]:
from torchvision import transforms

class VariancePatchDataset(Dataset):
    """Dataset yielding ``(inp, hr)`` pairs for multi-scale super-resolution.

    ``hr`` is the RGB image as a ``ToTensor`` tensor. ``inp`` concatenates three
    3-channel 512x512 maps along the channel axis (9 channels in total):

    * ``up128``: the image resized to 128x128 and bicubically upsampled back.
    * ``up256``: mean of the selected patches taken from the 256x256 resize,
      upsampled to 512x512.
    * ``up512``: mean of the selected patches taken from the original image,
      upsampled to 512x512.

    The selected patches are the ``M = 16384 // (K*K)`` non-overlapping KxK
    tiles with the highest grayscale variance.

    NOTE(review): the tiling and the fixed resize/upsample sizes assume
    512x512 source images - confirm against the dataset.
    NOTE(review): ``feature_cache`` is a plain per-process dict. With
    ``num_workers > 0`` each DataLoader worker fills its own copy, and
    ``clear_cache()`` called in the parent process does not reach them -
    confirm whether caching is wanted in that configuration.
    """

    def __init__(self, img_dir, K):
        """Index ``*.png`` files in ``img_dir`` and rank their KxK tiles by variance.

        Args:
            img_dir: directory that is globbed for ``*.png`` files.
            K: patch edge length; ``K*K`` must divide 16384.
        """
        assert 16384 % (K * K) == 0, "K*K must divide 16384"
        self.K = K
        self.M = 16384 // (K * K)  # patch budget: M patches of K*K pixels == 128*128 pixels
        self.to_tensor = transforms.ToTensor()
        self.paths = sorted(glob.glob(os.path.join(img_dir, '*.png')))

        # Only the patch coordinates are precomputed; features are built lazily.
        self.coords = []
        for path in self.paths:
            # Grayscale copy used solely for ranking; the file is closed right away.
            with Image.open(path) as img:
                arr = np.array(img.convert('L'))
            # Variance of every non-overlapping KxK tile, with its top-left (y, x).
            scores = [
                (arr[y:y + K, x:x + K].var(), y, x)
                for y in range(0, 512 - K + 1, K)
                for x in range(0, 512 - K + 1, K)
            ]
            scores.sort(key=lambda t: t[0], reverse=True)  # highest variance first
            self.coords.append([(y, x) for (_, y, x) in scores[:self.M]])

        # Lazy feature cache: idx -> (inp, hr), filled on first access.
        self.feature_cache = {}

    def __len__(self):
        """Number of indexed image files."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(inp, hr)`` for image ``idx``, building and caching it on first use."""
        cached = self.feature_cache.get(idx)
        if cached is not None:
            return cached

        with Image.open(self.paths[idx]) as img:
            pil = img.convert('RGB')
        hr = self.to_tensor(pil)

        # Global low-resolution view: 128x128 resize, upsampled back to 512x512.
        lr128 = pil.resize((128, 128), Image.BICUBIC)
        up128 = F.interpolate(
            self.to_tensor(lr128).unsqueeze(0),
            size=(512, 512), mode='bicubic'
        ).squeeze(0)

        # Collect the selected patches at both scales (raw, not yet upsampled).
        lr256 = pil.resize((256, 256), Image.BICUBIC)
        half = self.K // 2
        patches256, patches512 = [], []
        for (y, x) in self.coords[idx]:
            y2, x2 = y // 2, x // 2  # same tile in the half-resolution image
            patches256.append(self.to_tensor(lr256.crop((x2, y2, x2 + half, y2 + half))))
            patches512.append(self.to_tensor(pil.crop((x, y, x + self.K, y + self.K))))

        up256 = self._upsample_mean(patches256)
        up512 = self._upsample_mean(patches512)
        inp = torch.cat([up128, up256, up512], dim=0)

        # Keep the result for later accesses in this process.
        self.feature_cache[idx] = (inp, hr)
        return inp, hr

    @staticmethod
    def _upsample_mean(patches, size=(512, 512)):
        """Average equally sized ``[C, h, w]`` patches, then bicubically upsample to ``size``.

        Bicubic interpolation is a linear operator, so averaging before
        upsampling equals (up to float rounding) averaging the individually
        upsampled patches, without materialising one full-size tensor per patch.
        """
        mean_patch = torch.stack(patches, dim=0).mean(dim=0, keepdim=True)
        return F.interpolate(mean_patch, size=size, mode='bicubic').squeeze(0)

    def clear_cache(self):
        """Drop every cached feature held by this dataset object."""
        self.feature_cache.clear()


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("device : ", device)


In [None]:
# Image restoration network
class MyVersionSRCNN(nn.Module):
    """SRCNN variant with a depthwise front end and a residual output.

    Two depthwise 5x5 convolutions (``groups == in_channels``) are applied back
    to back, followed by a 1x1 convolution to 64 channels, a 1x1 reduction to
    32 channels and a 5x5 reconstruction convolution. The result is added to
    the first three input channels and clamped to [0, 1].
    """

    def __init__(self, in_channels, out_channels=3):
        super().__init__()
        # Depthwise stage: one 5x5 filter per input channel, applied twice.
        depthwise_first = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels)
        depthwise_second = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels)
        # Pointwise 1x1 mixes the channels up to 64 features.
        pointwise = nn.Conv2d(in_channels, 64, kernel_size=1)
        # 1x1 dimensionality reduction, then 5x5 reconstruction.
        reduction = nn.Conv2d(64, 32, kernel_size=1)
        reconstruction = nn.Conv2d(32, out_channels, kernel_size=5, padding=2)

        self.layers = nn.Sequential(
            depthwise_first,
            depthwise_second,
            pointwise,
            nn.ReLU(inplace=True),
            reduction,
            nn.ReLU(inplace=True),
            reconstruction,
        )

    def forward(self, x):
        # Residual learning: the stack predicts a correction to the first
        # three input channels (x[:, :3]).
        correction = self.layers(x)
        base = x[:, :3, :, :]
        return torch.clamp(correction + base, 0, 1)  # output range [0, 1]

# The dataset concatenates three RGB maps (up128, up256, up512), so the network input has 9 channels.
model1 = MyVersionSRCNN(in_channels=9).to(device)

In [None]:
# 2) Basic SRCNN (adapted for 9-channel input)
class BasicSRCNN(nn.Module):
    """Three-stage SRCNN: 9x9 feature extraction, 1x1 mapping, 5x5 reconstruction.

    The output is clamped to [0, 1]; there is no residual connection.
    """

    def __init__(self, in_channels=9, out_channels=3):
        super().__init__()
        feature_extraction = nn.Conv2d(in_channels, 64, kernel_size=9, padding=4)
        nonlinear_mapping = nn.Conv2d(64, 32, kernel_size=1)
        reconstruction = nn.Conv2d(32, out_channels, kernel_size=5, padding=2)
        self.net = nn.Sequential(
            feature_extraction,
            nn.ReLU(inplace=True),
            nonlinear_mapping,
            nn.ReLU(inplace=True),
            reconstruction,
        )

    def forward(self, x):
        restored = self.net(x)
        return torch.clamp(restored, 0, 1)

# Baseline instance for the 9-channel input, moved to the selected device.
model2 = BasicSRCNN(in_channels=9).to(device)


In [None]:
# 3) Residual SRCNN
class ResidualSRCNN(nn.Module):
    """SRCNN stack (9x9 -> 1x1 -> 5x5) whose output is added to the first three input channels.

    The sum is clamped to [0, 1].
    """

    def __init__(self, in_channels=9, out_channels=3):
        super().__init__()
        feature_extraction = nn.Conv2d(in_channels, 64, kernel_size=9, padding=4)
        nonlinear_mapping = nn.Conv2d(64, 32, kernel_size=1)
        reconstruction = nn.Conv2d(32, out_channels, kernel_size=5, padding=2)
        self.net = nn.Sequential(
            feature_extraction,
            nn.ReLU(inplace=True),
            nonlinear_mapping,
            nn.ReLU(inplace=True),
            reconstruction,
        )

    def forward(self, x):
        # Residual connection: the stack's output is added to x[:, :3].
        correction = self.net(x)
        base = x[:, :3, :, :]
        return torch.clamp(correction + base, 0, 1)

# Residual variant for the 9-channel input, moved to the selected device.
model3 = ResidualSRCNN(in_channels=9).to(device)


In [None]:
# 4) Depthwise-Separable SRCNN
class DepthwiseSRCNN(nn.Module):
    """SRCNN whose first layer is split into a depthwise 9x9 and a pointwise 1x1 convolution.

    The remaining layers are a 1x1 mapping to 32 channels and a 5x5
    reconstruction; the output is clamped to [0, 1] with no residual connection.
    """

    def __init__(self, in_channels=9, out_channels=3):
        super().__init__()
        # Depthwise 9x9: one filter per input channel (groups == in_channels).
        depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=9, padding=4, groups=in_channels)
        # Pointwise 1x1 mixes the channels up to 64 features.
        pointwise = nn.Conv2d(in_channels, 64, kernel_size=1)
        mapping = nn.Conv2d(64, 32, kernel_size=1)
        reconstruction = nn.Conv2d(32, out_channels, kernel_size=5, padding=2)
        self.net = nn.Sequential(
            depthwise,
            pointwise,
            nn.ReLU(inplace=True),
            mapping,
            nn.ReLU(inplace=True),
            reconstruction,
        )

    def forward(self, x):
        restored = self.net(x)
        return torch.clamp(restored, 0, 1)

# Depthwise-separable variant for the 9-channel input, moved to the selected device.
model4 = DepthwiseSRCNN(in_channels=9).to(device)


In [None]:
from torchsummary import summary
# Print a layer-by-layer summary of every network for a 9x512x512 input.
# Order: MyVersionSRCNN, BasicSRCNN, ResidualSRCNN, DepthwiseSRCNN.
for _net in (model1, model2, model3, model4):
    summary(_net, (9, 512, 512))

In [None]:
# Report how many CPU cores the runtime exposes.
print("CPU 코어 수:", os.cpu_count())

In [None]:
# Build the train/test datasets (K=8 patches) and their loaders.
# Both loaders share every setting except shuffling.
_loader_kwargs = dict(
    batch_size=16,
    num_workers=3,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2,
)
train_dataset = VariancePatchDataset(img_dir='/dataset/new_COCO2/Train', K=8)
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_dataset = VariancePatchDataset(img_dir='/dataset/new_COCO2/Test', K=8)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)


In [None]:
# Sanity-check the dataset sizes and the tensor shapes of one sample from each split.
print(f"Train dataset : {len(train_dataset)}")
print(f"Test dataset : {len(test_dataset)}")

# First training sample: the 9-channel input and its full-resolution target.
sample_input, sample_output = train_dataset[0]
print(f"Train dataset input shape(3개의 scale의 rgb이미지): {sample_input.shape}")
print(f"Train dataset output shape(복원된 이미지): {sample_output.shape}")

# First test sample (the names above are reused).
sample_input, sample_output = test_dataset[0]
print(f"Test dataset input shape: {sample_input.shape}")
print(f"Test dataset output shape: {sample_output.shape}")

In [None]:
# Hyperparameters shared by all four models
learning_rate = 1e-3      # initial Adam learning rate
criterion = nn.MSELoss()  # loss function
num_epochs = 30           # number of epochs


def _adam_with_step_decay(model):
    """Return ``(optimizer, scheduler)``: Adam at ``learning_rate`` with a x0.1 decay every 10 epochs."""
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    return optimizer, scheduler


# One independent optimizer/scheduler pair per model.
optimizer1, scheduler1 = _adam_with_step_decay(model1)  # model 1
optimizer2, scheduler2 = _adam_with_step_decay(model2)  # model 2
optimizer3, scheduler3 = _adam_with_step_decay(model3)  # model 3
optimizer4, scheduler4 = _adam_with_step_decay(model4)  # model 4


In [None]:
import time
from torch.cuda.amp import autocast, GradScaler

def train(model, device, train_loader, optimizer, criterion, sceduler, num_epochs, time_limit_hours=10):
    """Train ``model`` with mixed precision and return it with per-epoch metrics.

    Args:
        model: network to optimise (already on ``device``).
        device: device the batches are moved to.
        train_loader: sized iterable of ``(input, target)`` batches. If it has a
            ``dataset`` attribute exposing ``clear_cache()``, that cache is
            cleared at the start of every epoch.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(output, target)``.
        sceduler: learning-rate scheduler stepped once per epoch. The
            misspelled name is kept so keyword callers keep working.
        num_epochs: maximum number of epochs.
        time_limit_hours: training stops after the first epoch that ends
            beyond this many hours of wall-clock time.

    Returns:
        ``(model, train_loss, train_psnr)`` where the two lists hold the
        per-epoch batch averages.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty")

    start_time = time.time()
    train_loss = []
    train_psnr = []

    # AMP is only meaningful on CUDA; disabling it elsewhere avoids warnings.
    use_amp = torch.device(device).type == 'cuda' and torch.cuda.is_available()
    # Created once so the dynamic loss scale carries over between epochs.
    scaler = GradScaler(enabled=use_amp)
    # Use the loader's own dataset instead of a notebook-level global.
    clear_cache = getattr(getattr(train_loader, 'dataset', None), 'clear_cache', None)

    for epoch in range(num_epochs):
        if callable(clear_cache):
            clear_cache()  # reset cached features every epoch
        model.train()
        epoch_loss = 0
        epoch_psnr = 0
        for input_tensor, target in train_loader:
            input_tensor = input_tensor.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                output = model(input_tensor)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # PSNR is a metric only, so it is computed without tracking gradients.
            with torch.no_grad():
                psnr = compute_psnr(output.detach(), target)

            epoch_loss += loss.item()
            epoch_psnr += psnr.item()

        # The body previously referenced an undefined name `scheduler` here.
        sceduler.step()

        avg_loss = epoch_loss / num_batches
        avg_psnr = epoch_psnr / num_batches
        train_loss.append(avg_loss)
        train_psnr.append(avg_psnr)

        print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}, Avg PSNR: {avg_psnr:.2f} dB")

        # Enforce the wall-clock limit at epoch granularity.
        elapsed_time = time.time() - start_time
        if elapsed_time > time_limit_hours * 3600:
            print(f"Training stopped: {time_limit_hours}-hour limit reached.")
            break

    return model, train_loss, train_psnr


In [None]:
# Train model 1 (MyVersionSRCNN); keeps the returned model and the per-epoch loss / PSNR lists.
print('1번 모델 훈련 : MyVersionSRCNN')
train_model1, train_loss1, train_psnr1 = train(model1, device, train_dataloader, optimizer1, criterion, scheduler1, num_epochs)


In [None]:
# Plot model 1's per-epoch training loss and PSNR side by side.
fig, (ax_loss, ax_psnr) = plt.subplots(1, 2, figsize=(12, 6))

ax_loss.plot(train_loss1)
ax_loss.set_title('Training Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')

ax_psnr.plot(train_psnr1)
ax_psnr.set_title('Training PSNR')
ax_psnr.set_xlabel('Epoch')
ax_psnr.set_ylabel('PSNR')

plt.show()

In [None]:
# Train model 2 (BasicSRCNN); keeps the returned model and the per-epoch loss / PSNR lists.
print('2번 모델 훈련 : BasicSRCNN')
train_model2, train_loss2, train_psnr2 = train(model2, device, train_dataloader, optimizer2, criterion, scheduler2, num_epochs)


In [None]:
# Plot model 2's per-epoch training loss and PSNR side by side.
fig, (ax_loss, ax_psnr) = plt.subplots(1, 2, figsize=(12, 6))

ax_loss.plot(train_loss2)
ax_loss.set_title('Training Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')

ax_psnr.plot(train_psnr2)
ax_psnr.set_title('Training PSNR')
ax_psnr.set_xlabel('Epoch')
ax_psnr.set_ylabel('PSNR')

plt.show()

In [None]:
# Train model 3 (ResidualSRCNN); keeps the returned model and the per-epoch loss / PSNR lists.
print('3번 모델 훈련 : ResidualSRCNN')
train_model3, train_loss3, train_psnr3 = train(model3, device, train_dataloader, optimizer3, criterion, scheduler3, num_epochs)


In [None]:
# Plot model 3's per-epoch training loss and PSNR side by side.
fig, (ax_loss, ax_psnr) = plt.subplots(1, 2, figsize=(12, 6))

ax_loss.plot(train_loss3)
ax_loss.set_title('Training Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')

ax_psnr.plot(train_psnr3)
ax_psnr.set_title('Training PSNR')
ax_psnr.set_xlabel('Epoch')
ax_psnr.set_ylabel('PSNR')

plt.show()

In [None]:
# Train model 4 (DepthwiseSRCNN); keeps the returned model and the per-epoch loss / PSNR lists.
print('4번 모델 훈련 : DepthwiseSRCNN')
train_model4, train_loss4, train_psnr4 = train(model4, device, train_dataloader, optimizer4, criterion, scheduler4, num_epochs)


In [None]:
# Plot model 4's per-epoch training loss and PSNR side by side.
fig, (ax_loss, ax_psnr) = plt.subplots(1, 2, figsize=(12, 6))

ax_loss.plot(train_loss4)
ax_loss.set_title('Training Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')

ax_psnr.plot(train_psnr4)
ax_psnr.set_title('Training PSNR')
ax_psnr.set_xlabel('Epoch')
ax_psnr.set_ylabel('PSNR')

plt.show()

In [None]:
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    test_psnr = 0
    num_batches = len(test_loader)

    with torch.no_grad():
        for input_tensor, target in test_loader:
            input_tensor = input_tensor.to(device)
            target = target.to(device)

            output = model(input_tensor)
            loss = criterion(output, target)

            # PSNR 계산
            psnr = compute_psnr(output, target)

            test_loss += loss.item()
            test_psnr += psnr.item()

    avg_loss = test_loss / num_batches
    avg_psnr = test_psnr / num_batches

    print(f"\nTest Results:")
    print(f"Average Loss: {avg_loss:.4f}, Average PSNR: {avg_psnr:.2f} dB")
    return avg_loss, avg_psnr

# Evaluate every trained model on the test loader, printing a header before each run.
_test_runs = [
    ('\n모델 1 테스트 결과 : MyVersionSRCNN', train_model1),
    ('\n모델 2 테스트 결과 : BasicSRCNN', train_model2),
    ('\n모델 3 테스트 결과 : ResidualSRCNN', train_model3),
    ('\n모델 4 테스트 결과 : DepthwiseSRCNN', train_model4),
]
_test_results = []
for _header, _net in _test_runs:
    print(_header)
    _test_results.append(test(_net, device, test_dataloader, criterion))

(
    (test_loss1, test_psnr1),
    (test_loss2, test_psnr2),
    (test_loss3, test_psnr3),
    (test_loss4, test_psnr4),
) = _test_results

# Side-by-side comparison of the test results (optional)
print("\n--- 최종 테스트 결과 비교 ---")
print(f"MyVersionSRCNN: Avg Loss = {test_loss1:.4f}, Avg PSNR = {test_psnr1:.2f} dB")
print(f"BasicSRCNN:     Avg Loss = {test_loss2:.4f}, Avg PSNR = {test_psnr2:.2f} dB")
print(f"ResidualSRCNN:  Avg Loss = {test_loss3:.4f}, Avg PSNR = {test_psnr3:.2f} dB")
print(f"DepthwiseSRCNN: Avg Loss = {test_loss4:.4f}, Avg PSNR = {test_psnr4:.2f} dB")


In [None]:
import matplotlib.pyplot as plt
def display_random_test_image_reconstruction(model, device, test_dataset):
    """Show a randomly chosen test image next to the model's reconstruction of it.

    Args:
        model: trained network.
        device: device the input batch is moved to.
        test_dataset: dataset whose items are ``(input_tensor, target)`` pairs.
    """
    # Pick one sample at random.
    random_idx = random.randint(0, len(test_dataset) - 1)
    input_tensor, target = test_dataset[random_idx]

    # Add a batch dimension, move to the device and run inference.
    model.eval()
    with torch.no_grad():
        reconstructed_output = model(input_tensor.unsqueeze(0).to(device))

    def _as_hwc_array(tensor):
        # CPU copy, leading singleton dim dropped if present, (C, H, W) -> (H, W, C).
        return tensor.cpu().squeeze(0).permute(1, 2, 0).numpy()

    panels = [
        ('Original Image', _as_hwc_array(target)),
        ('Reconstructed Image', _as_hwc_array(reconstructed_output)),
    ]

    plt.figure(figsize=(12, 6))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')

    plt.show()

# Show one random reconstruction per trained model (each call draws its own random sample).
for _banner, _net in (
    ("\n--- 랜덤 테스트 이미지 복원 결과 (MyVersionSRCNN) ---", train_model1),
    ("\n--- 랜덤 테스트 이미지 복원 결과 (BasicSRCNN) ---", train_model2),
    ("\n--- 랜덤 테스트 이미지 복원 결과 (ResidualSRCNN) ---", train_model3),
    ("\n--- 랜덤 테스트 이미지 복원 결과 (DepthwiseSRCNN) ---", train_model4),
):
    print(_banner)
    display_random_test_image_reconstruction(_net, device, test_dataset)

In [None]:
# 2) masked_psnr function
def masked_psnr(output, target, coords, K):
    """PSNR (peak 255) over the pixels that lie outside the given KxK patches.

    Args:
        output, target: 4-D tensors of matching shape (documented by the
            caller as [B, 3, 512, 512]), expected on a 0-255 scale.
        coords: iterable of ``(y, x)`` top-left corners; the KxK window at
            each corner is excluded from the error.
        K: edge length of the excluded windows.

    Returns:
        0-dim tensor: ``10 * log10(255**2 / mean squared error of kept pixels)``.
    """
    squared_error = (output - target) ** 2
    # 1 where a pixel counts towards the error, 0 inside every excluded window.
    keep = torch.ones_like(squared_error)
    for top, left in coords:
        keep[:, :, top:top + K, left:left + K] = 0
    kept_mse = torch.sum(squared_error * keep) / torch.sum(keep)
    return 10 * torch.log10(255.0 ** 2 / kept_mse)

In [None]:
def evaluate_masked_psnr(model, device, dataloader, dataset, K):
    """Average ``masked_psnr`` over every sample served by ``dataloader``.

    Each sample's patch coordinates are looked up as
    ``dataset.coords[batch_index * dataloader.batch_size + position_in_batch]``,
    so the result is only meaningful when the loader iterates ``dataset`` in
    order (no shuffling). Outputs and targets are scaled by 255 before scoring.

    Returns:
        Mean masked PSNR as a float, or 0 when no sample was scored.
    """
    model.eval()
    dataset.clear_cache()  # start the evaluation from an empty feature cache
    psnr_sum = 0
    n_scored = 0

    with torch.no_grad():
        for batch_no, (input_tensor, target) in enumerate(dataloader):
            input_tensor = input_tensor.to(device)
            target = target.to(device)
            output = model(input_tensor)

            for offset in range(output.size(0)):
                # Position of this sample in the dataset, assuming in-order batches.
                sample_idx = batch_no * dataloader.batch_size + offset
                if sample_idx >= len(dataset):
                    print(f"Warning: Original dataset index {sample_idx} out of bounds.")
                    continue
                score = masked_psnr(
                    output[offset].unsqueeze(0) * 255.0,
                    target[offset].unsqueeze(0) * 255.0,
                    dataset.coords[sample_idx],
                    K,
                )
                psnr_sum += score.item()
                n_scored += 1

    return psnr_sum / n_scored if n_scored > 0 else 0

# The mask must use the patch size the dataset used when it selected its
# coordinates; a different K (previously a hard-coded 16) masks regions that
# do not match the selected patches.

# Masked PSNR for MyVersionSRCNN
avg_masked_psnr1 = evaluate_masked_psnr(train_model1, device, test_dataloader, test_dataset, K=test_dataset.K)
print(f"MyVersionSRCNN Masked PSNR on test set: {avg_masked_psnr1:.2f} dB")

# Masked PSNR for BasicSRCNN
avg_masked_psnr2 = evaluate_masked_psnr(train_model2, device, test_dataloader, test_dataset, K=test_dataset.K)
print(f"BasicSRCNN Masked PSNR on test set: {avg_masked_psnr2:.2f} dB")

# Masked PSNR for ResidualSRCNN
avg_masked_psnr3 = evaluate_masked_psnr(train_model3, device, test_dataloader, test_dataset, K=test_dataset.K)
print(f"ResidualSRCNN Masked PSNR on test set: {avg_masked_psnr3:.2f} dB")

# Masked PSNR for DepthwiseSRCNN
avg_masked_psnr4 = evaluate_masked_psnr(train_model4, device, test_dataloader, test_dataset, K=test_dataset.K)
print(f"DepthwiseSRCNN Masked PSNR on test set: {avg_masked_psnr4:.2f} dB")



In [None]:
import matplotlib.pyplot as plt
import numpy as np
def plot_reconstruction_and_inputs(model, device, dataset, idx, K):
    """Visualise one dataset sample, its input components and the model's reconstruction.

    A 2x3 figure shows:
      1. the original image,
      2. the three input components (up128, up256, up512),
      3. the selected patches of the original image (everything else black),
      4. the selected patches of the 256x256 resize (everything else black).
    A second figure shows the reconstructed image.

    Args:
        model (nn.Module): trained model.
        device: device the input is moved to ('cuda' or 'cpu').
        dataset (Dataset): dataset exposing ``paths``, ``coords`` and ``clear_cache()``.
        idx (int): index of the sample to visualise.
        K (int): patch size.
    """
    dataset.clear_cache()  # reset the feature cache before fetching the sample

    # Sample tensors plus the metadata needed to redraw the selected patches.
    input_tensor, target = dataset[idx]
    image_path = dataset.paths[idx]
    coords = dataset.coords[idx]

    # Inference on a single-item batch.
    model.eval()
    with torch.no_grad():
        batch = input_tensor.unsqueeze(0).to(device)
        reconstructed_output = model(batch).squeeze(0).cpu()

    def _hwc(tensor):
        # (C, H, W) tensor -> (H, W, C) numpy array on the CPU.
        return tensor.permute(1, 2, 0).cpu().numpy()

    original_image = _hwc(target)
    reconstructed_image_np = reconstructed_output.permute(1, 2, 0).numpy()

    # The 9 input channels are three stacked RGB maps.
    up128_image_np = _hwc(input_tensor[:3])
    up256_image_np = _hwc(input_tensor[3:6])
    up512_image_np = _hwc(input_tensor[6:])

    # Copy only the selected KxK patches of the original image onto black.
    pil_512 = Image.open(image_path).convert('RGB')
    img_512_np = np.array(pil_512)
    masked_512 = np.zeros_like(img_512_np)
    for (y, x) in coords:
        masked_512[y:y+K, x:x+K] = img_512_np[y:y+K, x:x+K]

    # Same for the 256x256 resize, with coordinates and patch size halved.
    img_256_np = np.array(pil_512.resize((256, 256), Image.BICUBIC))
    masked_256 = np.zeros_like(img_256_np)
    half = K // 2
    for (y, x) in coords:
        y2, x2 = y // 2, x // 2
        masked_256[y2:y2+half, x2:x2+half] = img_256_np[y2:y2+half, x2:x2+half]

    # 2x3 overview grid, filled row by row.
    panels = [
        (original_image, "Original 512x512"),
        (up128_image_np, "Input: Up128"),
        (up256_image_np, "Input: Up256 (Aggregated Patches)"),
        (up512_image_np, "Input: Up512 (Aggregated Patches)"),
        (masked_512, "Selected 512x512 Patches Only"),
        (masked_256, "Selected 256x256 Patches Only"),
    ]
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    for ax, (image, title) in zip(axes.flat, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    # The reconstruction gets its own figure.
    plt.figure(figsize=(6, 6))
    plt.imshow(reconstructed_image_np)
    plt.title("Reconstructed Image")
    plt.axis('off')
    plt.show()


# Pick one random test image and visualise it with every model.
random_test_idx = random.randint(0, len(test_dataset) - 1)
patch_size = test_dataset.K

print(f"\n--- 시각화 샘플 인덱스: {random_test_idx} ---")

for _banner, _net in (
    ("\n--- MyVersionSRCNN 결과 ---", train_model1),
    ("\n--- BasicSRCNN 결과 ---", train_model2),
    ("\n--- ResidualSRCNN 결과 ---", train_model3),
    ("\n--- DepthwiseSRCNN 결과 ---", train_model4),
):
    print(_banner)
    plot_reconstruction_and_inputs(_net, device, test_dataset, random_test_idx, patch_size)
