# Patch-based 4K Super-Resolution Training with Mixed Precision

This notebook implements a **patch-based training strategy** for 16× super-resolution (256×256 → 4096×4096).

## Training Strategy:
- **Train on patches**: Extract small patches from large images
- **LR patch size**: 64×64 or 128×128
- **HR patch size**: 1024×1024 or 2048×2048 (16× larger per side, i.e. 256× as many pixels)
- **Inference**: Sliding window over full image, stitch results

## Advantages:
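- **Low memory**: only small patches are processed at a time, so GPU memory depends on patch size, batch size and model width rather than on the full 4096×4096 image (reduce the batch size if you run out of memory)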
- **Fast iteration**: Small patches train quickly
- **Scalable**: Works for any resolution
- **Mixed precision**: Further reduces memory

---

## 1. Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from tqdm import tqdm
import cv2
import random

## 2. Device Configuration

In [None]:
# Pick the GPU when one is available, otherwise fall back to the CPU
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f'使用设备: {device}')
if _cuda_ok:
    # Report the GPU model and its total memory in GiB
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'显存: {_gpu_props.total_memory / 1024**3:.2f} GB')

## 3. Patch-based Dataset

### Key Features:
- Randomly extracts patches from 4K images during training
- LR patch: 64×64, HR patch: 1024×1024 (16× scale)
- Each epoch sees different random patches
- Data augmentation: random horizontal/vertical flips and random 90° rotations

In [None]:
class PatchSRDataset(Dataset):
    """Patch-based super-resolution dataset.

    Every item is a random HR crop taken from one of the images in ``hr_dir``;
    the matching LR patch is produced on the fly by bicubic downsampling.

    Args:
        hr_dir: directory containing the HR ``*.png`` images (4096×4096 here).
        lr_patch_size: side length of the LR patch (e.g. 64 or 128).
        scale_factor: upscaling factor (16); the HR patch side is
            ``lr_patch_size * scale_factor``.
        patches_per_image: number of patches drawn from each image per epoch.
        augment: whether to apply random flips / 90° rotations.
    """
    def __init__(self, hr_dir, lr_patch_size=64, scale_factor=16, 
                 patches_per_image=10, augment=True):
        self.hr_dir = Path(hr_dir)
        self.lr_patch_size = lr_patch_size
        self.hr_patch_size = lr_patch_size * scale_factor
        self.scale_factor = scale_factor
        self.patches_per_image = patches_per_image
        self.augment = augment
        
        # Sorted so a given index always refers to the same image
        self.hr_images = sorted(self.hr_dir.glob('*.png'))
        
        print(f"数据集信息:")
        print(f"  图像数量: {len(self.hr_images)}")
        print(f"  LR patch: {lr_patch_size}×{lr_patch_size}")
        print(f"  HR patch: {self.hr_patch_size}×{self.hr_patch_size}")
        print(f"  每张图像patch数: {patches_per_image}")
        print(f"  总patch数/epoch: {len(self.hr_images) * patches_per_image}")
    
    def __len__(self):
        # Each image contributes ``patches_per_image`` (random) items per epoch
        return len(self.hr_images) * self.patches_per_image
    
    def augment_patch(self, lr_patch, hr_patch):
        """Apply the same random flips / 90° rotation to an LR/HR patch pair.

        Returns contiguous copies: flips and rotations produce views with
        negative strides, which ``torch.from_numpy`` cannot wrap.
        """
        # Random horizontal flip
        if random.random() > 0.5:
            lr_patch = np.fliplr(lr_patch)
            hr_patch = np.fliplr(hr_patch)
        
        # Random vertical flip
        if random.random() > 0.5:
            lr_patch = np.flipud(lr_patch)
            hr_patch = np.flipud(hr_patch)
        
        # Random rotation by k * 90 degrees
        k = random.randint(0, 3)
        if k > 0:
            lr_patch = np.rot90(lr_patch, k)
            hr_patch = np.rot90(hr_patch, k)
        
        return lr_patch.copy(), hr_patch.copy()

    def _load_rgb(self, img_idx):
        """Read image ``img_idx`` from disk and return it as an RGB array."""
        path = self.hr_images[img_idx]
        bgr = cv2.imread(str(path))
        if bgr is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            raise IOError(f"Failed to read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def _random_crop(self, hr_img):
        """Return a random ``hr_patch_size`` × ``hr_patch_size`` crop of ``hr_img``.

        An image smaller than the patch in either dimension is resized to the
        patch size instead.
        """
        size = self.hr_patch_size
        h, w = hr_img.shape[:2]
        max_y = h - size
        max_x = w - size
        # An offset range of 0 means the image is exactly patch-sized in that
        # dimension, which is still a valid crop; only negative means too small.
        if max_y < 0 or max_x < 0:
            return cv2.resize(hr_img, (size, size))
        y = random.randint(0, max_y)
        x = random.randint(0, max_x)
        return hr_img[y:y + size, x:x + size]
    
    def __getitem__(self, idx):
        # Consecutive indices map to the same source image
        img_idx = idx // self.patches_per_image
        
        hr_patch = self._random_crop(self._load_rgb(img_idx))
        
        # Matching LR patch via bicubic downsampling
        lr_patch = cv2.resize(hr_patch, (self.lr_patch_size, self.lr_patch_size), 
                              interpolation=cv2.INTER_CUBIC)
        
        if self.augment:
            lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)
        
        # HWC uint8 -> CHW float tensors in [0, 1]
        lr_tensor = torch.from_numpy(lr_patch.transpose(2, 0, 1)).float() / 255.0
        hr_tensor = torch.from_numpy(hr_patch.transpose(2, 0, 1)).float() / 255.0
        
        return lr_tensor, hr_tensor

## 4. Efficient U-Net Model for 16× SR

### Architecture:
- Designed for 16× super-resolution (64×64 → 1024×1024)
- Uses progressive upsampling (4 stages of 2× each)
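- Roughly 17–18M parameters with `base_channels=48` (most of them in the 4×4 bottleneck and first decoder stage); the exact count is printed in Section 8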
- BatchNorm for training stability

In [None]:
class DoubleConv(nn.Module):
    """Two (3×3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # The attribute name and layer order define the state_dict keys
        # (double_conv.0.weight, ...), which saved checkpoints depend on.
        self.double_conv = nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Halve the spatial size with 2×2 max-pooling, then apply a DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Keep the attribute name: it prefixes this block's state_dict keys
        self.maxpool_conv = nn.Sequential(pool, convs)
    
    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upsampling block: 2× transposed conv, optional skip concat, then DoubleConv.

    Args:
        in_channels: channels of the tensor to upsample; the transposed conv
            reduces them to ``in_channels // 2``.
        out_channels: channels produced by the block.
        skip_channels: channels of the skip tensor concatenated after
            upsampling, or ``None`` for a plain (skip-less) upsampling block.
    """
    def __init__(self, in_channels, out_channels, skip_channels=None):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        
        conv_in = in_channels // 2 + (skip_channels if skip_channels is not None else 0)
        self.conv = DoubleConv(conv_in, out_channels)
        
        self.has_skip = skip_channels is not None
    
    def forward(self, x1, x2=None):
        """Upsample ``x1``; if this block has a skip, concatenate ``x2`` first.

        Raises:
            ValueError: if the block was built with ``skip_channels`` but
                ``x2`` is not given (the conv expects the concatenated width).
        """
        x1 = self.up(x1)
        if not self.has_skip:
            return self.conv(x1)
        if x2 is None:
            raise ValueError("this Up block was built with skip_channels and needs a skip tensor x2")
        # With odd encoder sizes the upsampled x1 can be a pixel smaller than
        # x2; pad it so torch.cat sees matching spatial dimensions.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        if diff_y or diff_x:
            x1 = nn.functional.pad(
                x1, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
            )
        return self.conv(torch.cat([x2, x1], dim=1))


class UNetPatch16x(nn.Module):
    """U-Net for patch-based 16× super-resolution.

    A 4-level U-Net runs at the LR resolution, then four skip-less 2× ``Up``
    stages raise the resolution 16×: e.g. a 64×64 (or 128×128) input gives a
    1024×1024 (or 2048×2048) output.

    Args:
        n_channels: input image channels.
        n_classes: output image channels.
        base_channels: width C of the first level; encoder level k uses C * 2**k.
    """
    def __init__(self, n_channels=3, n_classes=3, base_channels=48):
        super().__init__()
        c = base_channels
        
        # Encoder: H -> H/2 -> H/4 -> H/8 -> H/16 (bottleneck)
        self.inc = DoubleConv(n_channels, c)
        self.down1 = Down(c, 2 * c)
        self.down2 = Down(2 * c, 4 * c)
        self.down3 = Down(4 * c, 8 * c)
        self.down4 = Down(8 * c, 16 * c)
        
        # Decoder with skip connections, back to the input resolution H
        self.up1 = Up(16 * c, 8 * c, skip_channels=8 * c)
        self.up2 = Up(8 * c, 4 * c, skip_channels=4 * c)
        self.up3 = Up(4 * c, 2 * c, skip_channels=2 * c)
        self.up4 = Up(2 * c, c, skip_channels=c)
        
        # Super-resolution tail: four 2× stages (64 -> 128 -> 256 -> 512 -> 1024)
        self.up5 = Up(c, c, skip_channels=None)
        self.up6 = Up(c, c, skip_channels=None)
        self.up7 = Up(c, c, skip_channels=None)
        self.up8 = Up(c, c, skip_channels=None)
        
        # 1×1 projection back to image channels
        self.outc = nn.Conv2d(c, n_classes, kernel_size=1)
    
    def forward(self, x):
        # Encoder; every level is kept for the skip connections
        s1 = self.inc(x)        # C,   H
        s2 = self.down1(s1)     # 2C,  H/2
        s3 = self.down2(s2)     # 4C,  H/4
        s4 = self.down3(s3)     # 8C,  H/8
        y = self.down4(s4)      # 16C, H/16
        
        # Decoder: consume the skips deepest-first
        for up, skip in ((self.up1, s4), (self.up2, s3), (self.up3, s2), (self.up4, s1)):
            y = up(y, skip)     # ends at C, H
        
        # 2× per stage, 16× in total
        for up in (self.up5, self.up6, self.up7, self.up8):
            y = up(y)           # ends at C, 16H
        
        return self.outc(y)     # n_classes, 16H

## 5. Training Configuration

### Memory-Efficient Settings:
- **LR patch**: 64×64 (very small)
- **HR patch**: 1024×1024 (16× scale)
- **Batch size**: 4-8 (fits easily in GPU)
- **Mixed precision**: Enabled
- **Patches per image**: 10 (more variety per epoch)

In [None]:
# Training configuration
LR_PATCH_SIZE = 64          # LR patch side (64 or 128)
SCALE_FACTOR = 16           # upscaling factor
HR_PATCH_SIZE = LR_PATCH_SIZE * SCALE_FACTOR  # 1024 with the defaults

BATCH_SIZE = 4              # batch size (can be raised, e.g. to 8 or 16)
LEARNING_RATE = 2e-4        # initial learning rate
NUM_EPOCHS = 50             # number of epochs
PATCHES_PER_IMAGE = 10      # random patches drawn from each image per epoch

BASE_CHANNELS = 48          # model base width (32/48/64)
USE_MIXED_PRECISION = True  # mixed-precision training
TRAIN_SPLIT = 0.9           # fraction of the data used for training

# Data paths
HR_DIR = './dataset_4k/high_resolution'
CHECKPOINT_DIR = Path('./checkpoints_patch')
CHECKPOINT_DIR.mkdir(exist_ok=True)

_amp_label = '启用' if USE_MIXED_PRECISION else '禁用'
for _line in (
    "训练配置:",
    f"  LR patch大小: {LR_PATCH_SIZE}×{LR_PATCH_SIZE}",
    f"  HR patch大小: {HR_PATCH_SIZE}×{HR_PATCH_SIZE}",
    f"  放大倍数: {SCALE_FACTOR}×",
    f"  Batch size: {BATCH_SIZE}",
    f"  训练轮数: {NUM_EPOCHS}",
    f"  每张图像patches: {PATCHES_PER_IMAGE}",
    f"  混合精度: {_amp_label}",
    f"  模型基础通道数: {BASE_CHANNELS}",
):
    print(_line)

## 6. Data Loading

In [None]:
import copy

# Training dataset: random crops with flip/rotation augmentation
full_dataset = PatchSRDataset(
    hr_dir=HR_DIR,
    lr_patch_size=LR_PATCH_SIZE,
    scale_factor=SCALE_FACTOR,
    patches_per_image=PATCHES_PER_IMAGE,
    augment=True
)

# Validation view of the same images without augmentation. The shallow copy
# shares the sorted image list, so image indices mean the same in both.
val_base_dataset = copy.copy(full_dataset)
val_base_dataset.augment = False

# Split by *image*, not by patch: dataset index i comes from image
# i // patches_per_image, so a patch-level random_split would put crops of the
# same image into both sets and make the validation loss optimistic.
patches_per_img = full_dataset.patches_per_image
num_images = len(full_dataset.hr_images)
if num_images < 2:
    raise ValueError(
        f"An image-level train/val split needs at least 2 images in {HR_DIR}, found {num_images}"
    )
# Keep at least one image on each side of the split
num_train_images = min(max(int(TRAIN_SPLIT * num_images), 1), num_images - 1)
image_order = torch.randperm(num_images).tolist()
train_image_ids = image_order[:num_train_images]
val_image_ids = image_order[num_train_images:]

train_dataset = torch.utils.data.Subset(
    full_dataset,
    [img * patches_per_img + k for img in train_image_ids for k in range(patches_per_img)],
)
val_dataset = torch.utils.data.Subset(
    val_base_dataset,
    [img * patches_per_img + k for img in val_image_ids for k in range(patches_per_img)],
)
train_size = len(train_dataset)
val_size = len(val_dataset)

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,  # multi-process loading
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"\n训练集patches: {train_size}")
print(f"验证集patches: {val_size}")
print(f"训练批次数: {len(train_loader)}")

## 7. Visualize Sample Patches

In [None]:
# Preview a few (LR, HR) training pairs.
# Consecutive dataset indices come from the same image (idx // patches_per_image),
# so spread the sample indices over the whole dataset instead of using 0..4.
num_preview = 5
sample_indices = [k * len(full_dataset) // num_preview for k in range(num_preview)]

fig, axes = plt.subplots(2, num_preview, figsize=(15, 6))

for col, sample_idx in enumerate(sample_indices):
    lr_patch, hr_patch = full_dataset[sample_idx]
    
    # CHW tensors -> HWC numpy arrays for matplotlib
    lr_np = lr_patch.numpy().transpose(1, 2, 0)
    hr_np = hr_patch.numpy().transpose(1, 2, 0)
    
    # LR patch at native size
    axes[0, col].imshow(lr_np)
    axes[0, col].set_title(f'LR {LR_PATCH_SIZE}×{LR_PATCH_SIZE}')
    axes[0, col].axis('off')
    
    # HR patch, downsampled to 256×256 for display
    hr_display = cv2.resize(hr_np, (256, 256), interpolation=cv2.INTER_AREA)
    axes[1, col].imshow(hr_display)
    axes[1, col].set_title(f'HR {HR_PATCH_SIZE}×{HR_PATCH_SIZE}')
    axes[1, col].axis('off')

plt.tight_layout()
plt.savefig(CHECKPOINT_DIR / 'patch_samples.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"样本已保存至: {CHECKPOINT_DIR / 'patch_samples.png'}")

## 8. Model Initialization

In [None]:
# Build the model and move it to the selected device
model = UNetPatch16x(n_channels=3, n_classes=3, base_channels=BASE_CHANNELS)
model = model.to(device)

# Count all parameters
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"\n模型参数总数: {total_params:,}")

# Loss, optimizer, and a step LR schedule (halve the LR every 15 epochs)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

# Gradient scaler for mixed precision; None when AMP is disabled
scaler = None
if USE_MIXED_PRECISION:
    scaler = GradScaler()

for _msg in ("优化器: Adam", f"学习率: {LEARNING_RATE}", "损失函数: L1 Loss"):
    print(_msg)

## 9. Training and Validation Functions

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, scaler, device, use_amp=True):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network to train (switched to train mode).
        train_loader: iterable of ``(lr_patches, hr_patches)`` batches.
        criterion: loss function; assumed to return a per-batch mean
            (e.g. ``nn.L1Loss()``) — the epoch average weights it by batch size.
        optimizer: optimizer stepping ``model``'s parameters.
        scaler: GradScaler used when ``use_amp`` is true; if None, the step
            falls back to full precision.
        device: device the batches are moved to.
        use_amp: whether to run the forward pass under autocast.

    Returns:
        Mean training loss over all samples of the epoch, so a smaller final
        batch does not skew the average.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0
    
    pbar = tqdm(train_loader, desc='训练')
    for lr_patches, hr_patches in pbar:
        lr_patches = lr_patches.to(device)
        hr_patches = hr_patches.to(device)
        
        optimizer.zero_grad()
        
        if use_amp and scaler is not None:
            # Mixed precision: autocast forward, scaled backward to avoid fp16 underflow
            with autocast():
                outputs = model(lr_patches)
                loss = criterion(outputs, hr_patches)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(lr_patches)
            loss = criterion(outputs, hr_patches)
            loss.backward()
            optimizer.step()
        
        # .item() synchronizes with the device; call it once per step
        loss_value = loss.item()
        batch_size = lr_patches.size(0)
        loss_sum += loss_value * batch_size
        num_samples += batch_size
        pbar.set_postfix({'loss': f'{loss_value:.4f}'})
    
    if num_samples == 0:
        raise ValueError("train_loader produced no batches")
    return loss_sum / num_samples


def validate(model, val_loader, criterion, device, use_amp=True):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode).
        val_loader: iterable of ``(lr_patches, hr_patches)`` batches.
        criterion: loss function; assumed to return a per-batch mean
            (e.g. ``nn.L1Loss()``).
        device: device the batches are moved to.
        use_amp: whether to run the forward pass under autocast.

    Returns:
        Mean loss over all validation samples (each batch weighted by its
        size, so a smaller final batch does not skew the result).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0
    
    with torch.no_grad():
        for lr_patches, hr_patches in val_loader:
            lr_patches = lr_patches.to(device)
            hr_patches = hr_patches.to(device)
            
            if use_amp:
                with autocast():
                    outputs = model(lr_patches)
                    loss = criterion(outputs, hr_patches)
            else:
                outputs = model(lr_patches)
                loss = criterion(outputs, hr_patches)
            
            batch_size = lr_patches.size(0)
            loss_sum += loss.item() * batch_size
            num_samples += batch_size
    
    if num_samples == 0:
        raise ValueError("val_loader produced no batches")
    return loss_sum / num_samples

## 10. Training Loop

In [None]:
def _make_checkpoint(epoch, val_loss):
    """Collect the state needed to resume training or rebuild the model.

    Used for both the best-model and the periodic checkpoints so they share
    the same keys (including ``config``).
    """
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        # AMP loss-scale state; None when mixed precision is disabled
        'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
        'val_loss': val_loss,
        'config': {
            'lr_patch_size': LR_PATCH_SIZE,
            'scale_factor': SCALE_FACTOR,
            'base_channels': BASE_CHANNELS
        }
    }


# Training history
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')

print("\n开始训练...\n")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 50)
    
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device, USE_MIXED_PRECISION)
    
    # Validate
    val_loss = validate(model, val_loader, criterion, device, USE_MIXED_PRECISION)
    
    # Record
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    
    # Per-epoch LR schedule
    scheduler.step()
    
    print(f"\n训练损失: {train_loss:.6f}")
    print(f"验证损失: {val_loss:.6f}")
    print(f"学习率: {scheduler.get_last_lr()[0]:.6f}")
    
    # Keep the checkpoint with the lowest validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(_make_checkpoint(epoch, val_loss), CHECKPOINT_DIR / 'best_model.pth')
        print(f"✓ 保存最佳模型 (验证损失: {val_loss:.6f})")
    
    # Periodic checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(_make_checkpoint(epoch, val_loss), CHECKPOINT_DIR / f'checkpoint_epoch_{epoch+1}.pth')
        print(f"✓ 保存检查点: epoch_{epoch+1}")

print("\n训练完成！")
print(f"最佳验证损失: {best_val_loss:.6f}")

## 11. Training Curve

In [None]:
# The training loop reports epochs 1-based, so plot against the same numbering
epoch_numbers = range(1, len(history['train_loss']) + 1)

plt.figure(figsize=(10, 5))
plt.plot(epoch_numbers, history['train_loss'], label='Training Loss', marker='o')
plt.plot(epoch_numbers, history['val_loss'], label='Validation Loss', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Loss (L1)')
plt.title('Patch-based Training Curve')
plt.legend()
plt.grid(True)
plt.savefig(CHECKPOINT_DIR / 'training_curve.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"训练曲线已保存: {CHECKPOINT_DIR / 'training_curve.png'}")

## 12. Inference with Sliding Window

### Sliding Window Strategy:
- Divide 256×256 LR image into overlapping 64×64 patches
- Process each patch → 1024×1024 HR patch
- Stitch all HR patches → 4096×4096 final image
- Average the overlapping predictions (uniform weights) to reduce visible seams

In [None]:
def sliding_window_inference(model, lr_img_256, patch_size=64, overlap=16, device='cuda'):
    """
    使用滑动窗口对256×256图像进行16×超分辨率
    
    Args:
        model: 训练好的模型
        lr_img_256: 256×256的LR图像 (torch.Tensor, shape: 1,3,256,256)
        patch_size: patch大小（64）
        overlap: patch之间的重叠（16，用于平滑拼接）
        device: 设备
    
    Returns:
        sr_img_4k: 4096×4096的SR图像 (torch.Tensor)
    """
    model.eval()
    
    _, _, h, w = lr_img_256.shape
    stride = patch_size - overlap
    
    # 计算需要多少个patch
    num_patches_h = (h - overlap) // stride
    num_patches_w = (w - overlap) // stride
    
    # 输出图像大小
    out_h = num_patches_h * patch_size * 16
    out_w = num_patches_w * patch_size * 16
    
    # 创建输出canvas
    sr_img = torch.zeros(1, 3, out_h, out_w).to(device)
    weight_map = torch.zeros(1, 1, out_h, out_w).to(device)
    
    with torch.no_grad():
        with autocast():
            for i in range(num_patches_h):
                for j in range(num_patches_w):
                    # 提取LR patch
                    y = i * stride
                    x = j * stride
                    lr_patch = lr_img_256[:, :, y:y+patch_size, x:x+patch_size]
                    
                    # 超分辨率
                    sr_patch = model(lr_patch.to(device))
                    
                    # 放置到输出图像
                    out_y = i * stride * 16
                    out_x = j * stride * 16
                    sr_img[:, :, out_y:out_y+patch_size*16, out_x:out_x+patch_size*16] += sr_patch
                    weight_map[:, :, out_y:out_y+patch_size*16, out_x:out_x+patch_size*16] += 1
    
    # 平均重叠区域
    sr_img = sr_img / weight_map
    
    return sr_img

## 13. Test Inference on Full Image

In [None]:
# Load the best checkpoint. map_location lets a checkpoint saved on a GPU be
# restored on a CPU-only machine and puts the tensors on the active device.
checkpoint = torch.load(CHECKPOINT_DIR / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"已加载最佳模型 (Epoch {checkpoint['epoch']+1}, 验证损失: {checkpoint['val_loss']:.6f})")

# Pick the test image deterministically: glob() order is filesystem-dependent,
# while the dataset uses sorted order.
# NOTE(review): this image may also be in the training split; prefer a held-out image for a fair check.
test_hr_path = sorted(Path(HR_DIR).glob('*.png'))[0]
test_hr_4k = cv2.imread(str(test_hr_path))
if test_hr_4k is None:
    raise IOError(f"Failed to read image: {test_hr_path}")
test_hr_4k = cv2.cvtColor(test_hr_4k, cv2.COLOR_BGR2RGB)

# 256×256 LR input, bicubic-downsampled like the training LR patches
test_lr_256 = cv2.resize(test_hr_4k, (256, 256), interpolation=cv2.INTER_CUBIC)
lr_tensor = torch.from_numpy(test_lr_256.transpose(2, 0, 1)).float() / 255.0
lr_tensor = lr_tensor.unsqueeze(0)

print(f"\n输入图像: {lr_tensor.shape}")
print("开始滑动窗口推理...")

# Sliding-window inference
sr_tensor = sliding_window_inference(model, lr_tensor, patch_size=LR_PATCH_SIZE, 
                                     overlap=16, device=device)

print(f"输出图像: {sr_tensor.shape}")

# Visualize the result
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# LR input
axes[0].imshow(test_lr_256)
axes[0].set_title('Input (256×256)')
axes[0].axis('off')

# SR output (downsampled for display)
sr_np = sr_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
sr_np = np.clip(sr_np, 0, 1)
sr_display = cv2.resize(sr_np, (512, 512), interpolation=cv2.INTER_AREA)
axes[1].imshow(sr_display)
axes[1].set_title(f'SR Output ({sr_tensor.shape[2]}×{sr_tensor.shape[3]})')
axes[1].axis('off')

# Ground truth (downsampled for display)
hr_np = test_hr_4k / 255.0
hr_display = cv2.resize(hr_np, (512, 512), interpolation=cv2.INTER_AREA)
axes[2].imshow(hr_display)
axes[2].set_title('Ground Truth (4096×4096)')
axes[2].axis('off')

plt.tight_layout()
plt.savefig(CHECKPOINT_DIR / 'inference_result.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\n推理结果已保存: {CHECKPOINT_DIR / 'inference_result.png'}")

## 14. Evaluation Metrics

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def evaluate_full_images(model, hr_dir, num_samples=5, patch_size=64, device='cuda', eval_size=1024):
    """Evaluate SR quality on whole images using sliding-window inference.

    Each HR image is bicubic-downsampled to 256×256, super-resolved with
    ``sliding_window_inference`` and compared with the original via PSNR/SSIM.

    Args:
        model: trained SR model.
        hr_dir: directory of HR ``*.png`` images; the first ``num_samples``
            in sorted order are used.
        num_samples: maximum number of images to evaluate.
        patch_size: LR patch size used for sliding-window inference.
        device: device for inference.
        eval_size: side length both images are resized to (INTER_AREA) before
            scoring, which keeps SSIM affordable; ``None`` scores at native
            resolution. Downsampling before scoring tends to hide fine-detail
            errors, so reduced-size scores are optimistic.

    Returns:
        ``(mean PSNR in dB, mean SSIM)``.

    Raises:
        FileNotFoundError: if ``hr_dir`` contains no ``*.png`` images.
        IOError: if an image cannot be read.
    """
    # NOTE(review): the first sorted images may belong to the training split.
    hr_images = sorted(Path(hr_dir).glob('*.png'))[:num_samples]
    if not hr_images:
        raise FileNotFoundError(f"No .png images found in {hr_dir}")
    
    psnr_scores = []
    ssim_scores = []
    
    for hr_path in tqdm(hr_images, desc='评估'):
        # Read the HR image as RGB
        hr_4k = cv2.imread(str(hr_path))
        if hr_4k is None:
            raise IOError(f"Failed to read image: {hr_path}")
        hr_4k = cv2.cvtColor(hr_4k, cv2.COLOR_BGR2RGB)
        
        # 256×256 bicubic LR input
        lr_256 = cv2.resize(hr_4k, (256, 256), interpolation=cv2.INTER_CUBIC)
        lr_tensor = torch.from_numpy(lr_256.transpose(2, 0, 1)).float() / 255.0
        lr_tensor = lr_tensor.unsqueeze(0)
        
        # Inference
        sr_tensor = sliding_window_inference(model, lr_tensor, patch_size=patch_size, 
                                             overlap=16, device=device)
        
        # HWC float arrays in [0, 1]
        sr_np = sr_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        sr_np = np.clip(sr_np, 0, 1)
        hr_np = hr_4k / 255.0
        
        # Crop both to their common area in case the sizes differ
        min_h = min(sr_np.shape[0], hr_np.shape[0])
        min_w = min(sr_np.shape[1], hr_np.shape[1])
        sr_np = sr_np[:min_h, :min_w]
        hr_np = hr_np[:min_h, :min_w]
        
        if eval_size is not None:
            sr_np = cv2.resize(sr_np, (eval_size, eval_size), interpolation=cv2.INTER_AREA)
            hr_np = cv2.resize(hr_np, (eval_size, eval_size), interpolation=cv2.INTER_AREA)
        
        psnr_scores.append(psnr(hr_np, sr_np, data_range=1.0))
        ssim_scores.append(ssim(hr_np, sr_np, data_range=1.0, channel_axis=2))
    
    mean_psnr = np.mean(psnr_scores)
    mean_ssim = np.mean(ssim_scores)
    
    print(f"\n=== 评估结果 ===")
    print(f"平均PSNR: {mean_psnr:.2f} dB")
    print(f"平均SSIM: {mean_ssim:.4f}")
    
    return mean_psnr, mean_ssim

# Run the evaluation (metrics computed after resizing to eval_size=1024)
avg_psnr, avg_ssim = evaluate_full_images(model, HR_DIR, num_samples=5, 
                                          patch_size=LR_PATCH_SIZE, device=device)

## 15. Training Summary

In [None]:
# Assemble a human-readable report of the configuration and results
summary = f"""
基于Patch的4K超分辨率训练总结
{'='*60}

训练策略: Patch-based Training
混合精度: {'启用' if USE_MIXED_PRECISION else '禁用'}

数据配置:
  LR patch大小: {LR_PATCH_SIZE}×{LR_PATCH_SIZE}
  HR patch大小: {HR_PATCH_SIZE}×{HR_PATCH_SIZE}
  缩放倍数: {SCALE_FACTOR}×
  每张图像patches: {PATCHES_PER_IMAGE}

模型配置:
  架构: UNetPatch16x
  基础通道数: {BASE_CHANNELS}
  参数量: {total_params:,}

训练配置:
  Batch size: {BATCH_SIZE}
  训练轮数: {NUM_EPOCHS}
  初始学习率: {LEARNING_RATE}
  优化器: Adam
  损失函数: L1 Loss

训练结果:
  最佳验证损失: {best_val_loss:.6f}
  平均PSNR: {avg_psnr:.2f} dB
  平均SSIM: {avg_ssim:.4f}

模型保存位置: {CHECKPOINT_DIR / 'best_model.pth'}

推理方法:
  使用滑动窗口（overlap=16）拼接完整4K图像
"""

print(summary)

# Persist the report next to the checkpoints (UTF-8 for the Chinese text)
summary_path = CHECKPOINT_DIR / 'training_summary.txt'
summary_path.write_text(summary, encoding='utf-8')

print(f"\n总结已保存至: {summary_path}")