# Patch-based 16× Super-Resolution (32×32 → 512×512)

## 方案1: 直接16×放大

### 训练策略:
- **LR patch**: 32×32
- **HR patch**: 512×512
- **Scale**: 16×
- **优点**: 一个模型搞定，推理简单
- **缺点**: patch较小，感受野有限

### 推理流程 (256×256 → 4096×4096):
```
256×256 LR图像
  ↓ 切成 8×8 = 64个 32×32 patches
  ↓ 每个patch通过16×模型 (32→512)
  ↓ 拼接64个 512×512 patches
  ↓ 得到 4096×4096 HR图像
```

### 显存占用估算:
- Batch=4: 4×3×512×512 个元素 × 4字节 (FP32) ≈ 12MB (仅输出张量，不含中间特征图)
- 总显存: ~1-2GB

---

## 1. Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from tqdm import tqdm
import cv2
import random

## 2. Device Configuration

In [None]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
print(f'使用设备: {device}')

# When a GPU is present, report the name and total memory (GiB) of device 0.
if _has_cuda:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'显存: {_gpu_props.total_memory / 1024**3:.2f} GB')

## 3. Patch Dataset (32×32 → 512×512)

In [None]:
class PatchSRDataset16x(Dataset):
    """Random-crop patch dataset for super-resolution (32x32 -> 512x512 by default).

    Every ``*.png`` file in ``hr_dir`` is treated as a high-resolution image.
    Each item is a random HR crop of ``lr_patch_size * scale_factor`` pixels per
    side together with its bicubic-downsampled LR counterpart. Because crops
    are drawn with ``random``, the same index yields a different patch on
    every access; the index only selects which image is read.

    Args:
        hr_dir: Directory containing the HR ``.png`` images.
        lr_patch_size: Side length of the LR patch in pixels.
        scale_factor: Upscaling factor between LR and HR patches.
        patches_per_image: Number of dataset items mapped to each image.
        augment: Apply random flips / 90-degree rotations when True.
    """

    def __init__(self, hr_dir, lr_patch_size=32, scale_factor=16,
                 patches_per_image=20, augment=True):
        self.hr_dir = Path(hr_dir)
        self.lr_patch_size = lr_patch_size
        self.hr_patch_size = lr_patch_size * scale_factor
        self.scale_factor = scale_factor
        self.patches_per_image = patches_per_image
        self.augment = augment

        # Sorted so that index -> image mapping is stable across runs.
        self.hr_images = sorted(self.hr_dir.glob('*.png'))

        print("数据集信息:")
        print(f"  图像数量: {len(self.hr_images)}")
        print(f"  LR patch: {lr_patch_size}×{lr_patch_size}")
        print(f"  HR patch: {self.hr_patch_size}×{self.hr_patch_size}")
        print(f"  缩放倍数: {scale_factor}×")
        print(f"  每张图像patch数: {patches_per_image}")
        print(f"  总patch数/epoch: {len(self.hr_images) * patches_per_image}")

    def __len__(self):
        """Return the number of patches per epoch (images x patches_per_image)."""
        return len(self.hr_images) * self.patches_per_image

    def augment_patch(self, lr_patch, hr_patch):
        """Apply the same random flip / rotation to an (H, W, C) LR/HR pair.

        Returns contiguous copies, because the numpy flip/rot90 results are
        negatively-strided views that ``torch.from_numpy`` cannot consume.
        """
        if random.random() > 0.5:
            lr_patch = np.fliplr(lr_patch)
            hr_patch = np.fliplr(hr_patch)

        if random.random() > 0.5:
            lr_patch = np.flipud(lr_patch)
            hr_patch = np.flipud(hr_patch)

        k = random.randint(0, 3)  # number of 90-degree rotations
        if k > 0:
            lr_patch = np.rot90(lr_patch, k)
            hr_patch = np.rot90(hr_patch, k)

        return lr_patch.copy(), hr_patch.copy()

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor)`` as float CHW tensors scaled to [0, 1].

        Raises:
            IOError: If the selected image file cannot be decoded.
        """
        img_idx = idx // self.patches_per_image
        img_path = self.hr_images[img_idx]

        hr_img = cv2.imread(str(img_path))
        if hr_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f"Failed to read image: {img_path}")
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)

        h, w = hr_img.shape[:2]
        max_y = h - self.hr_patch_size
        max_x = w - self.hr_patch_size

        if max_y < 0 or max_x < 0:
            # Image is smaller than one HR patch: stretch the whole image.
            hr_patch = cv2.resize(hr_img, (self.hr_patch_size, self.hr_patch_size))
        else:
            # A side exactly equal to the patch size (max == 0) is still
            # croppable; random.randint(0, 0) simply yields offset 0.
            y = random.randint(0, max_y)
            x = random.randint(0, max_x)
            hr_patch = hr_img[y:y + self.hr_patch_size, x:x + self.hr_patch_size]

        lr_patch = cv2.resize(hr_patch, (self.lr_patch_size, self.lr_patch_size),
                              interpolation=cv2.INTER_CUBIC)

        if self.augment:
            lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        # HWC uint8 -> CHW float in [0, 1].
        lr_tensor = torch.from_numpy(lr_patch.transpose(2, 0, 1)).float() / 255.0
        hr_tensor = torch.from_numpy(hr_patch.transpose(2, 0, 1)).float() / 255.0

        return lr_tensor, hr_tensor

## 4. Compact 16× U-Net Model

轻量级设计：
- 更少的下采样层（32太小，不能下采样太多）
- 减少通道数
- 移除BatchNorm以节省显存

In [None]:
class ResidualBlock(nn.Module):
    """Residual block without BatchNorm: ``relu(x + conv2(relu(conv1(x))))``.

    Both convolutions are 3x3 with padding 1 and keep the channel count, so
    the output has the same shape as the input.
    """

    def __init__(self, channels):
        super().__init__()
        body_layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        ]
        self.conv = nn.Sequential(*body_layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Add the identity path to the conv branch, then activate the sum.
        summed = x + self.conv(x)
        return self.relu(summed)


class UNet16xSmall(nn.Module):
    """Compact encoder/decoder for 16x SR (32x32 -> 512x512 for a 32x32 input).

    Layout:
    - The input is small, so the encoder downsamples only twice (32 -> 16 -> 8)
      using stride-2 convolutions.
    - Residual blocks (no BatchNorm) refine features in the encoder stages and
      in the bottleneck.
    - Six nearest-neighbour x2 stages upsample 8 -> 16 -> 32 -> 64 -> 128 ->
      256 -> 512, i.e. 64x from the bottleneck and 16x relative to the input.
    - The forward pass is purely sequential; no encoder features are
      concatenated into the decoder.
    """

    def __init__(self, n_channels=3, n_classes=3, base_channels=32):
        super().__init__()
        c1 = base_channels
        c2 = base_channels * 2
        c4 = base_channels * 4

        # Stem: lift the input to c1 feature channels at full resolution.
        self.inc = nn.Sequential(
            nn.Conv2d(n_channels, c1, 3, padding=1),
            nn.ReLU(inplace=True),
        )

        # Encoder stage 1: halve the resolution, c1 -> c2 channels.
        self.down1 = nn.Sequential(
            nn.Conv2d(c1, c2, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResidualBlock(c2),
        )

        # Encoder stage 2: halve the resolution again, c2 -> c4 channels.
        self.down2 = nn.Sequential(
            nn.Conv2d(c2, c4, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResidualBlock(c4),
        )

        # Bottleneck: two residual blocks at the lowest resolution.
        self.bottleneck = nn.Sequential(
            ResidualBlock(c4),
            ResidualBlock(c4),
        )

        # Decoder: six x2 stages (2**6 = 64x) given as (in, out) channel pairs.
        up_plan = [
            (c4, c4),  # 8 -> 16
            (c4, c2),  # 16 -> 32
            (c2, c2),  # 32 -> 64
            (c2, c1),  # 64 -> 128
            (c1, c1),  # 128 -> 256
            (c1, c1),  # 256 -> 512
        ]
        self.up_blocks = nn.ModuleList(
            [self._make_up_block(src, dst) for src, dst in up_plan]
        )

        # 1x1 projection to the output channels.
        self.outc = nn.Conv2d(c1, n_classes, 1)

    def _make_up_block(self, in_ch, out_ch):
        """Build one x2 stage: nearest upsample followed by two 3x3 conv + ReLU."""
        stage = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        # Shape comments assume base_channels=32 and a 32x32 input.
        feats = self.inc(x)              # 32 ch, 32x32
        feats = self.down1(feats)        # 64 ch, 16x16
        feats = self.down2(feats)        # 128 ch, 8x8
        feats = self.bottleneck(feats)   # 128 ch, 8x8

        # Progressive upsampling back to the HR resolution.
        for stage in self.up_blocks:
            feats = stage(feats)

        return self.outc(feats)          # n_classes ch, 512x512

## 5. Training Configuration

In [None]:
# ---- Training configuration ----

# Patch geometry.
LR_PATCH_SIZE = 32                              # LR patch side (pixels)
SCALE_FACTOR = 16                               # upscaling factor
HR_PATCH_SIZE = LR_PATCH_SIZE * SCALE_FACTOR    # HR patch side: 512

# Optimisation schedule.
BATCH_SIZE = 8
LEARNING_RATE = 2e-4
NUM_EPOCHS = 80             # extra epochs to compensate for the small patches
PATCHES_PER_IMAGE = 20      # random crops drawn from each image per epoch

# Model / runtime switches.
BASE_CHANNELS = 32          # channel width of the first conv layer
USE_MIXED_PRECISION = True
TRAIN_SPLIT = 0.9           # fraction of patches used for training

# Paths; the checkpoint directory is created up front.
HR_DIR = './dataset_4k/high_resolution'
CHECKPOINT_DIR = Path('./checkpoints_16x_small')
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Echo the effective configuration.
print("\n".join([
    "=== 方案1: 直接16×放大 ===",
    f"LR patch: {LR_PATCH_SIZE}×{LR_PATCH_SIZE}",
    f"HR patch: {HR_PATCH_SIZE}×{HR_PATCH_SIZE}",
    f"放大倍数: {SCALE_FACTOR}×",
    f"Batch size: {BATCH_SIZE}",
    f"训练轮数: {NUM_EPOCHS}",
    f"每张图patches: {PATCHES_PER_IMAGE}",
]))

## 6. Data Loading

In [None]:
# Build one patch dataset over every HR image.
full_dataset = PatchSRDataset16x(
    HR_DIR,
    lr_patch_size=LR_PATCH_SIZE,
    scale_factor=SCALE_FACTOR,
    patches_per_image=PATCHES_PER_IMAGE,
    augment=True,
)

# Random train/validation split over patch indices.
# NOTE(review): the split is per patch index, not per image, and the dataset
# maps several indices to the same image with a fresh random crop each time.
# Both subsets therefore sample from the same images, and the validation
# subset is augmented too. Consider splitting by image if a clean hold-out
# set is needed.
total_patches = len(full_dataset)
train_size = int(TRAIN_SPLIT * total_patches)
val_size = total_patches - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size]
)

# Both loaders share everything except shuffling.
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)

print(f"\n训练集patches: {train_size}")
print(f"验证集patches: {val_size}")
print(f"训练批次数: {len(train_loader)}")

## 7. Visualize Sample Patches

In [None]:
# Preview the first five dataset items: LR patches on top, HR patches below.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

lr_title = f'LR {LR_PATCH_SIZE}×{LR_PATCH_SIZE}'
hr_title = f'HR {HR_PATCH_SIZE}×{HR_PATCH_SIZE}'

for sample_idx, (lr_ax, hr_ax) in enumerate(zip(axes[0], axes[1])):
    lr_chw, hr_chw = full_dataset[sample_idx]

    # Tensors are CHW; imshow expects HWC.
    for ax, chw, title in ((lr_ax, lr_chw, lr_title), (hr_ax, hr_chw, hr_title)):
        ax.imshow(chw.permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.savefig(CHECKPOINT_DIR / 'patch_samples.png', dpi=150)
plt.show()

## 8. Model Initialization

In [None]:
# Instantiate the network and move it to the selected device.
model = UNet16xSmall(n_channels=3, n_classes=3, base_channels=BASE_CHANNELS)
model = model.to(device)

# Report the total parameter count.
total_params = sum(map(torch.numel, model.parameters()))
print(f"\n模型参数: {total_params:,}")

# L1 loss, Adam, and a cosine schedule that spans the whole run.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# A gradient scaler is only created when mixed precision is requested.
scaler = None
if USE_MIXED_PRECISION:
    scaler = GradScaler()

## 9. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, scaler, device, use_amp):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network to optimise; switched to train mode.
        loader: Sized iterable yielding ``(lr, hr)`` tensor batches.
        criterion: Loss callable taking ``(prediction, target)``.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler used for mixed precision, or ``None``.
        device: Target device (``torch.device`` or device string).
        use_amp: Request mixed precision. It is only honoured when a scaler
            is supplied and ``device`` is a CUDA device.

    Returns:
        The sum of per-batch losses divided by ``len(loader)``.

    Raises:
        ValueError: If ``loader`` has no batches.
    """
    if len(loader) == 0:
        raise ValueError("train_epoch() received an empty loader")

    model.train()
    # torch.autocast is used rather than torch.cuda.amp.autocast: the latter's
    # constructor has no `device_type` parameter, so passing it raises TypeError.
    amp_enabled = (
        bool(use_amp)
        and scaler is not None
        and torch.device(device).type == 'cuda'
    )
    running_loss = 0.0

    pbar = tqdm(loader, desc='训练')
    for lr_batch, hr_batch in pbar:
        lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)
        optimizer.zero_grad()

        if amp_enabled:
            with torch.autocast(device_type='cuda'):
                out = model(lr_batch)
                loss = criterion(out, hr_batch)
            # Scale the loss before backward; the scaler then steps and updates.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(lr_batch)
            loss = criterion(out, hr_batch)
            loss.backward()
            optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    return running_loss / len(loader)


def validate(model, loader, criterion, device, use_amp):
    """Evaluate ``model`` on ``loader`` without gradients; return the mean batch loss.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Sized iterable yielding ``(lr, hr)`` tensor batches.
        criterion: Loss callable taking ``(prediction, target)``.
        device: Target device (``torch.device`` or device string).
        use_amp: Request mixed precision; only honoured on CUDA devices.

    Returns:
        The sum of per-batch losses divided by ``len(loader)``.

    Raises:
        ValueError: If ``loader`` has no batches.
    """
    from contextlib import nullcontext

    if len(loader) == 0:
        raise ValueError("validate() received an empty loader")

    model.eval()
    # torch.autocast is used rather than torch.cuda.amp.autocast: the latter's
    # constructor has no `device_type` parameter, so passing it raises TypeError.
    amp_enabled = bool(use_amp) and torch.device(device).type == 'cuda'
    running_loss = 0.0

    with torch.no_grad():
        for lr_batch, hr_batch in loader:
            lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

            amp_ctx = torch.autocast(device_type='cuda') if amp_enabled else nullcontext()
            with amp_ctx:
                out = model(lr_batch)
                loss = criterion(out, hr_batch)

            running_loss += loss.item()

    return running_loss / len(loader)

## 10. Training Loop

In [None]:
# Per-epoch loss history and the best validation loss seen so far.
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')

print("\n开始训练...\n")

for epoch in range(NUM_EPOCHS):
    epoch_no = epoch + 1
    print(f"\nEpoch {epoch_no}/{NUM_EPOCHS}")
    print("-" * 50)

    # One training pass followed by one validation pass.
    train_loss = train_epoch(
        model, train_loader, criterion, optimizer, scaler, device,
        USE_MIXED_PRECISION,
    )
    val_loss = validate(model, val_loader, criterion, device, USE_MIXED_PRECISION)

    for key, value in (('train_loss', train_loss), ('val_loss', val_loss)):
        history[key].append(value)
    scheduler.step()

    # The learning rate is read after scheduler.step(), so the printed value
    # is the rate the next epoch will use.
    print(f"\n训练损失: {train_loss:.6f}")
    print(f"验证损失: {val_loss:.6f}")
    print(f"学习率: {scheduler.get_last_lr()[0]:.6f}")

    # Keep the checkpoint with the lowest validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_loss': val_loss,
            'config': {
                'lr_patch': LR_PATCH_SIZE,
                'hr_patch': HR_PATCH_SIZE,
                'scale': SCALE_FACTOR,
            },
        }
        torch.save(best_state, CHECKPOINT_DIR / 'best_model.pth')
        print(f"✓ 保存最佳模型")

    # Periodic weights-only snapshot every 20 epochs.
    if epoch_no % 20 == 0:
        torch.save(model.state_dict(), CHECKPOINT_DIR / f'epoch_{epoch_no}.pth')

print(f"\n训练完成！最佳验证损失: {best_val_loss:.6f}")

## 11. Plot Training Curve

In [None]:
# Plot training and validation loss per epoch, then save the figure.
curve_fig, curve_ax = plt.subplots(figsize=(10, 5))

curve_specs = (
    ('train_loss', 'Training', 'o'),
    ('val_loss', 'Validation', 's'),
)
for hist_key, series_label, series_marker in curve_specs:
    curve_ax.plot(history[hist_key], label=series_label, marker=series_marker)

curve_ax.set_xlabel('Epoch')
curve_ax.set_ylabel('Loss')
curve_ax.set_title('16× Training Curve (32→512)')
curve_ax.legend()
curve_ax.grid(True)

curve_fig.savefig(CHECKPOINT_DIR / 'curve.png', dpi=150)
plt.show()

## 12. Inference: 256×256 → 4096×4096

将256×256图像切成8×8=64个32×32 patches，每个放大到512×512，拼接成4096×4096

In [None]:
def inference_256_to_4096(model, lr_256, patch_size=32, device='cuda'):
    """Super-resolve an image tile by tile and stitch the results.

    The input is cut into non-overlapping ``patch_size`` x ``patch_size``
    tiles. Each tile is passed through ``model`` and written to the matching
    position of the output. For the nominal case (a 1x3x256x256 input,
    32-pixel tiles and a 16x model) this yields 8x8 = 64 tiles and a
    1x3x4096x4096 result. The output tile size and channel count are read
    from the model's first output rather than being hard-coded.

    Args:
        model: Network mapping an (N, C, p, p) tile to an upscaled tile;
            switched to eval mode.
        lr_256: Input tensor of shape (N, C, H, W), with H and W both
            multiples of ``patch_size``.
        patch_size: Tile side length in input pixels.
        device: Device used for inference and for the returned tensor.

    Returns:
        A float32 tensor on ``device`` holding the stitched output.

    Raises:
        ValueError: If the input is not 4-D, is empty, or is not divisible
            into whole tiles.
    """
    from contextlib import nullcontext

    if lr_256.dim() != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got {lr_256.dim()}-D")
    batch, _, height, width = lr_256.shape
    if height == 0 or width == 0 or height % patch_size or width % patch_size:
        raise ValueError(
            f"input size {height}x{width} must be a non-zero multiple of "
            f"patch_size={patch_size}"
        )

    model.eval()
    rows, cols = height // patch_size, width // patch_size

    # torch.autocast is used rather than torch.cuda.amp.autocast: the latter's
    # constructor has no `device_type` parameter, so passing it raises TypeError.
    # Mixed precision is only enabled when running on a CUDA device.
    on_cuda = torch.device(device).type == 'cuda'
    amp_ctx = torch.autocast(device_type='cuda') if on_cuda else nullcontext()

    sr_img = None
    with torch.no_grad(), amp_ctx:
        for i in range(rows):
            for j in range(cols):
                # Extract one input tile.
                y, x = i * patch_size, j * patch_size
                lr_patch = lr_256[:, :, y:y + patch_size, x:x + patch_size]

                sr_patch = model(lr_patch.to(device))

                if sr_img is None:
                    # Allocate the canvas once the output tile shape is known.
                    out_ch = sr_patch.shape[1]
                    out_h, out_w = sr_patch.shape[-2:]
                    sr_img = torch.zeros(
                        batch, out_ch, rows * out_h, cols * out_w, device=device
                    )

                # Paste the tile; the assignment casts to the canvas dtype.
                sr_img[:, :, i * out_h:(i + 1) * out_h,
                       j * out_w:(j + 1) * out_w] = sr_patch

    return sr_img

## 13. Test Inference

In [None]:
# 加载最佳模型
checkpoint = torch.load(CHECKPOINT_DIR / 'best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# 读取测试图像
test_hr_path = list(Path(HR_DIR).glob('*.png'))[0]
test_hr_4k = cv2.imread(str(test_hr_path))
test_hr_4k = cv2.cvtColor(test_hr_4k, cv2.COLOR_BGR2RGB)

# 生成256×256 LR
test_lr_256 = cv2.resize(test_hr_4k, (256, 256), interpolation=cv2.INTER_CUBIC)
lr_tensor = torch.from_numpy(test_lr_256.transpose(2, 0, 1)).float() / 255.0
lr_tensor = lr_tensor.unsqueeze(0)

print(f"输入: {lr_tensor.shape}")
print("推理中...")

sr_tensor = inference_256_to_4096(model, lr_tensor, patch_size=LR_PATCH_SIZE, device=device)

print(f"输出: {sr_tensor.shape}")

# 可视化
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(test_lr_256)
axes[0].set_title('Input 256×256')
axes[0].axis('off')

sr_np = sr_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
sr_np = np.clip(sr_np, 0, 1)
sr_display = cv2.resize(sr_np, (512, 512), interpolation=cv2.INTER_AREA)
axes[1].imshow(sr_display)
axes[1].set_title(f'Output {sr_tensor.shape[2]}×{sr_tensor.shape[3]}')
axes[1].axis('off')

hr_display = cv2.resize(test_hr_4k / 255.0, (512, 512), interpolation=cv2.INTER_AREA)
axes[2].imshow(hr_display)
axes[2].set_title('GT 4096×4096')
axes[2].axis('off')

plt.tight_layout()
plt.savefig(CHECKPOINT_DIR / 'result.png', dpi=150)
plt.show()

## 14. Evaluation

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def evaluate(model, hr_dir, num_samples=5):
    """Print and return mean PSNR / SSIM over the first ``num_samples`` PNGs.

    For each image (sorted order) in ``hr_dir``: the HR image is bicubically
    resized to 256x256, super-resolved with ``inference_256_to_4096``, and both
    the result and the HR image are area-resized to 1024x1024 before the
    metrics are computed on [0, 1] data.

    Returns:
        ``(mean_psnr, mean_ssim)``.
    """
    metric_size = (1024, 1024)
    sample_paths = sorted(Path(hr_dir).glob('*.png'))[:num_samples]
    psnr_scores = []
    ssim_scores = []

    for hr_path in tqdm(sample_paths, desc='评估'):
        hr_4k = cv2.cvtColor(cv2.imread(str(hr_path)), cv2.COLOR_BGR2RGB)

        # HWC uint8 -> 1xCxHxW float in [0, 1].
        lr_256 = cv2.resize(hr_4k, (256, 256), interpolation=cv2.INTER_CUBIC)
        lr_chw = torch.from_numpy(lr_256.transpose(2, 0, 1)).float() / 255.0
        lr_tensor = lr_chw.unsqueeze(0)

        sr_tensor = inference_256_to_4096(model, lr_tensor, LR_PATCH_SIZE, device)

        sr_np = np.clip(sr_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0), 0, 1)
        hr_np = hr_4k / 255.0

        # Metrics are computed on downsampled copies of both images.
        sr_small = cv2.resize(sr_np, metric_size, interpolation=cv2.INTER_AREA)
        hr_small = cv2.resize(hr_np, metric_size, interpolation=cv2.INTER_AREA)

        psnr_scores.append(psnr(hr_small, sr_small, data_range=1.0))
        ssim_scores.append(ssim(hr_small, sr_small, data_range=1.0, channel_axis=2))

    mean_psnr = np.mean(psnr_scores)
    mean_ssim = np.mean(ssim_scores)

    print(f"\n=== 评估结果 ===")
    print(f"PSNR: {mean_psnr:.2f} dB")
    print(f"SSIM: {mean_ssim:.4f}")
    return np.mean(psnr_scores), np.mean(ssim_scores)

avg_psnr, avg_ssim = evaluate(model, HR_DIR, 5)