# Cascade 8× Super-Resolution (64×64 → 512×512)

## 方案2: 级联8×模型

### 训练策略:
- **LR patch**: 64×64
- **HR patch**: 512×512
- **Scale**: 8×
- **优点**: patch大小合理，训练更稳定
- **推理**: 需要级联或二次放大

### 推理流程 (256×256 → 4096×4096):

#### 方法A: 8× + 2× 级联
```
256×256 LR
  ↓ 切成4×4=16个 64×64 patches
  ↓ 每个通过8×模型 → 512×512
  ↓ 拼接 → 2048×2048
  ↓ 简单2×上采样(bicubic或轻量模型)
  ↓ 4096×4096 HR
```

#### 方法B: 两次8×级联
```
256×256 LR
  ↓ 第一个8×模型 → 2048×2048
  ↓ 切成8×8=64个 256×256 patches
  ↓ 每个下采样到64×64 (4×缩小)
  ↓ 通过第二个8×模型 → 512×512 (每个patch净放大2×)
  ↓ 拼接8×8 → 4096×4096 HR
```

### 显存占用:
- Batch=8: 8×3×512×512 ≈ 24MB (fp32 输出张量)
- 总显存: ~1GB (粗略估计；主要来自512×512分辨率下的中间激活，实际请用 `torch.cuda.max_memory_allocated()` 实测)

---

## 1. Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from tqdm import tqdm
import cv2
import random

## 2. Device Configuration

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'使用设备: {device}')
if device.type == 'cuda':
    # Report the name and total memory (in GiB) of GPU 0.
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    gpu_total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f'显存: {gpu_total_gib:.2f} GB')

## 3. Patch Dataset (64×64 → 512×512)

In [None]:
class PatchSRDataset8x(Dataset):
    """Random-crop patch dataset for 8x SR (64x64 LR -> 512x512 HR by default).

    Every ``*.png`` in ``hr_dir`` contributes ``patches_per_image`` samples per
    epoch: sample ``idx`` is drawn from image ``idx // patches_per_image``. A new
    random HR crop is taken on each access, and the LR input is produced by
    bicubic downscaling of that crop.

    Args:
        hr_dir: Directory containing the high-resolution PNG images.
        lr_patch_size: Side length of the LR patch.
        scale_factor: Upscaling factor; HR patch side = lr_patch_size * scale_factor.
        patches_per_image: Number of samples taken from each image per epoch.
        augment: If True, apply the same random flips / 90-degree rotations to
            the LR and HR patch of a pair.
    """
    def __init__(self, hr_dir, lr_patch_size=64, scale_factor=8, 
                 patches_per_image=15, augment=True):
        self.hr_dir = Path(hr_dir)
        self.lr_patch_size = lr_patch_size
        self.hr_patch_size = lr_patch_size * scale_factor
        self.scale_factor = scale_factor
        self.patches_per_image = patches_per_image
        self.augment = augment
        
        self.hr_images = sorted(self.hr_dir.glob('*.png'))
        
        print(f"数据集信息:")
        print(f"  图像数量: {len(self.hr_images)}")
        print(f"  LR patch: {lr_patch_size}×{lr_patch_size}")
        print(f"  HR patch: {self.hr_patch_size}×{self.hr_patch_size}")
        print(f"  缩放倍数: {scale_factor}×")
        print(f"  每张图像patch数: {patches_per_image}")
        print(f"  总patch数/epoch: {len(self.hr_images) * patches_per_image}")
    
    def __len__(self):
        return len(self.hr_images) * self.patches_per_image
    
    def augment_patch(self, lr_patch, hr_patch):
        """Apply one random flip/rotation jointly to an (LR, HR) pair of HxWxC arrays.

        Each of a left-right flip and an up-down flip is applied with
        probability 1/2, followed by a rotation by k*90 degrees (k uniform in
        0..3). Contiguous copies are returned because flips/rotations produce
        negative-stride views that ``torch.from_numpy`` cannot wrap.
        """
        if random.random() > 0.5:
            lr_patch = np.fliplr(lr_patch)
            hr_patch = np.fliplr(hr_patch)
        
        if random.random() > 0.5:
            lr_patch = np.flipud(lr_patch)
            hr_patch = np.flipud(hr_patch)
        
        k = random.randint(0, 3)
        if k > 0:
            lr_patch = np.rot90(lr_patch, k)
            hr_patch = np.rot90(hr_patch, k)
        
        return lr_patch.copy(), hr_patch.copy()

    def _crop_hr_patch(self, hr_img):
        """Return a ``hr_patch_size`` x ``hr_patch_size`` crop of an HxWxC image.

        Images at least ``hr_patch_size`` on both sides are cropped at a
        uniformly random position; a side that fits exactly is cropped at
        offset 0 instead of being resized. Images smaller than the patch on
        either side are resized to the patch size (aspect ratio not preserved).
        """
        size = self.hr_patch_size
        h, w = hr_img.shape[:2]
        max_y = h - size
        max_x = w - size

        if max_y < 0 or max_x < 0:
            return cv2.resize(hr_img, (size, size))

        y = random.randint(0, max_y)
        x = random.randint(0, max_x)
        return hr_img[y:y + size, x:x + size]
    
    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor)`` as float CHW tensors scaled by 1/255.

        Raises:
            OSError: if OpenCV cannot decode the source image.
        """
        img_path = self.hr_images[idx // self.patches_per_image]
        
        hr_img = cv2.imread(str(img_path))
        # cv2.imread signals failure by returning None rather than raising.
        if hr_img is None:
            raise OSError(f'Failed to read image: {img_path}')
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
        
        hr_patch = self._crop_hr_patch(hr_img)
        
        lr_patch = cv2.resize(hr_patch, (self.lr_patch_size, self.lr_patch_size), 
                             interpolation=cv2.INTER_CUBIC)
        
        if self.augment:
            lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)
        
        lr_tensor = torch.from_numpy(lr_patch.transpose(2, 0, 1)).float() / 255.0
        hr_tensor = torch.from_numpy(hr_patch.transpose(2, 0, 1)).float() / 255.0
        
        return lr_tensor, hr_tensor

## 4. U-Net 8× Model

设计原则：
- 64×64输入足够大，可以有更多下采样
- 使用skip connections保留细节
- 移除BatchNorm节省显存

In [None]:
class ResBlock(nn.Module):
    """Residual block computing ``relu(x + conv3x3(relu(conv3x3(x))))``.

    Both 3x3 convolutions use padding=1 and keep ``channels`` unchanged, so
    the identity shortcut can be added without any projection.
    """

    def __init__(self, channels):
        super().__init__()
        # Kept as a single Sequential named ``conv`` so the state_dict keys
        # stay ``conv.0.*`` / ``conv.2.*``.
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        ]
        self.conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.conv(x)
        # The in-place ReLU acts on the freshly allocated sum, never on ``x``.
        return self.relu(residual + x)


class UNet8x(nn.Module):
    """U-Net for 8x super-resolution (64x64 -> 512x512).

    With ``c = base_ch`` and spatial sizes given for a 64x64 input:

    * encoder: ``inc`` (c, 64) -> ``down1`` (2c, 32) -> ``down2`` (4c, 16)
      -> ``down3`` (8c, 8); each downsampling step is a stride-2 3x3 conv;
    * ``bottleneck``: two residual blocks at 8c, 8x8;
    * decoder ``up1``..``up3``: nearest 2x upsample, concatenate the matching
      encoder map on channels, fuse to 4c / 2c / c at 16 / 32 / 64;
    * ``up4``..``up6``: three skip-free nearest 2x stages (64 -> 128 -> 256 -> 512);
    * ``outc``: 1x1 projection to ``n_classes`` channels.

    Input height and width should be multiples of 8 so that each upsampled
    decoder map matches its encoder skip for concatenation.
    """

    def __init__(self, n_channels=3, n_classes=3, base_ch=32):
        super().__init__()
        c1, c2, c4, c8 = base_ch, base_ch * 2, base_ch * 4, base_ch * 8

        # Stem at input resolution.
        self.inc = nn.Sequential(
            nn.Conv2d(n_channels, c1, 3, padding=1),
            nn.ReLU(inplace=True),
            ResBlock(c1),
        )

        # Encoder: every stage halves H and W.
        self.down1 = self._make_down_block(c1, c2)
        self.down2 = self._make_down_block(c2, c4)
        self.down3 = self._make_down_block(c4, c8)

        self.bottleneck = nn.Sequential(ResBlock(c8), ResBlock(c8))

        # Decoder with skips: (upsampled channels, skip channels, output channels).
        self.up1 = self._make_up_block(c8, c4, c4)
        self.up2 = self._make_up_block(c4, c2, c2)
        self.up3 = self._make_up_block(c2, c1, c1)

        # Skip-free 2x stages taking the decoder output from 1x to 8x.
        self.up4 = self._make_simple_up(c1, c1)
        self.up5 = self._make_simple_up(c1, c1)
        self.up6 = self._make_simple_up(c1, c1)

        self.outc = nn.Conv2d(c1, n_classes, 1)

    @staticmethod
    def _make_down_block(in_ch, out_ch):
        """Stride-2 3x3 conv + ReLU + residual block (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResBlock(out_ch),
        )

    def _make_up_block(self, in_ch, skip_ch, out_ch):
        """Decoder stage: nearest 2x upsample ('up') and skip fusion ('conv').

        'conv' consumes the upsampled map (in_ch channels) concatenated with the
        skip (skip_ch channels) and produces out_ch channels.
        """
        fuse = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            ResBlock(out_ch),
        )
        return nn.ModuleDict({'up': nn.Upsample(scale_factor=2, mode='nearest'), 'conv': fuse})

    def _make_simple_up(self, in_ch, out_ch):
        """Skip-free stage: nearest 2x upsample, then two 3x3 conv + ReLU."""
        layers = [nn.Upsample(scale_factor=2, mode='nearest')]
        for conv_in in (in_ch, out_ch):
            layers += [nn.Conv2d(conv_in, out_ch, 3, padding=1), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder; keep the pre-bottleneck maps for the skip connections.
        skip1 = self.inc(x)          # c,  H,   W
        skip2 = self.down1(skip1)    # 2c, H/2, W/2
        skip3 = self.down2(skip2)    # 4c, H/4, W/4
        feats = self.bottleneck(self.down3(skip3))  # 8c, H/8, W/8

        # Decoder: upsample, concatenate [upsampled, skip] on channels, fuse.
        for stage, skip in ((self.up1, skip3), (self.up2, skip2), (self.up3, skip1)):
            feats = stage['conv'](torch.cat([stage['up'](feats), skip], dim=1))

        # H x W -> 8H x 8W without skips.
        for stage in (self.up4, self.up5, self.up6):
            feats = stage(feats)

        return self.outc(feats)

## 5. Training Configuration

In [None]:
# Training configuration
LR_PATCH_SIZE = 64
SCALE_FACTOR = 8
# Derived rather than hard-coded so the HR size always equals what the dataset
# produces (lr_patch_size * scale_factor) if either value is edited.
HR_PATCH_SIZE = LR_PATCH_SIZE * SCALE_FACTOR

BATCH_SIZE = 8
LEARNING_RATE = 2e-4
NUM_EPOCHS = 60
PATCHES_PER_IMAGE = 15

BASE_CHANNELS = 32
USE_MIXED_PRECISION = True
TRAIN_SPLIT = 0.9

HR_DIR = './dataset_4k/high_resolution'
CHECKPOINT_DIR = Path('./checkpoints_8x_cascade')
# parents=True so a nested checkpoint path also works.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

print("=== 方案2: 级联8×模型 ===")
print(f"LR patch: {LR_PATCH_SIZE}×{LR_PATCH_SIZE}")
print(f"HR patch: {HR_PATCH_SIZE}×{HR_PATCH_SIZE}")
print(f"放大倍数: {SCALE_FACTOR}×")
print(f"Batch size: {BATCH_SIZE}")
print(f"训练轮数: {NUM_EPOCHS}")

## 6. Data Loading

In [None]:
full_dataset = PatchSRDataset8x(
    hr_dir=HR_DIR,
    lr_patch_size=LR_PATCH_SIZE,
    scale_factor=SCALE_FACTOR,
    patches_per_image=PATCHES_PER_IMAGE,
    augment=True
)

# Split by source image, not by patch index. Patch idx is cropped from image
# idx // patches_per_image, so a patch-level random_split would place crops of
# the same image in both splits and make the validation loss optimistic.
num_images = len(full_dataset.hr_images)
if num_images >= 2:
    # Keep at least one image on each side of the split.
    num_train_images = min(max(int(TRAIN_SPLIT * num_images), 1), num_images - 1)
    image_order = torch.randperm(num_images).tolist()
    train_indices = [
        img * full_dataset.patches_per_image + k
        for img in sorted(image_order[:num_train_images])
        for k in range(full_dataset.patches_per_image)
    ]
    val_indices = [
        img * full_dataset.patches_per_image + k
        for img in sorted(image_order[num_train_images:])
        for k in range(full_dataset.patches_per_image)
    ]
    train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
    train_size, val_size = len(train_dataset), len(val_dataset)
else:
    # An image-level split is impossible with fewer than two images; fall back
    # to the patch-level split (validation then overlaps training).
    print('Warning: fewer than 2 images found; using a patch-level train/val split')
    train_size = int(TRAIN_SPLIT * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size]
    )

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=2, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2, pin_memory=True
)

print(f"\n训练集patches: {train_size}")
print(f"验证集patches: {val_size}")

## 7. Visualize Patches

In [None]:
# Show LR/HR pairs from evenly spaced dataset indices. Consecutive indices map to
# the same source image (idx // PATCHES_PER_IMAGE), so indices 0..4 would all be
# crops of the first image.
num_show = min(5, len(full_dataset))
if num_show == 0:
    print('Dataset is empty; nothing to visualize.')
else:
    sample_indices = np.linspace(0, len(full_dataset) - 1, num_show).round().astype(int)
    # squeeze=False keeps axes 2-D even when only one column is shown.
    fig, axes = plt.subplots(2, num_show, figsize=(3 * num_show, 6), squeeze=False)

    for col, idx in enumerate(sample_indices):
        lr, hr = full_dataset[int(idx)]

        axes[0, col].imshow(lr.numpy().transpose(1, 2, 0))
        axes[0, col].set_title(f'LR {LR_PATCH_SIZE}×{LR_PATCH_SIZE}')
        axes[0, col].axis('off')

        axes[1, col].imshow(hr.numpy().transpose(1, 2, 0))
        axes[1, col].set_title(f'HR {HR_PATCH_SIZE}×{HR_PATCH_SIZE}')
        axes[1, col].axis('off')

    plt.tight_layout()
    plt.savefig(CHECKPOINT_DIR / 'samples.png', dpi=150)
    plt.show()

## 8. Model Initialization

In [None]:
# Build the 8x network and move it to the selected device.
model = UNet8x(n_channels=3, n_classes=3, base_ch=BASE_CHANNELS)
model = model.to(device)

total_params = sum(param.numel() for param in model.parameters())
print(f"\n模型参数: {total_params:,}")

# L1 reconstruction loss, Adam, and cosine learning-rate annealing with T_max=NUM_EPOCHS.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Gradient scaler for mixed precision; stays None when mixed precision is off.
scaler = None
if USE_MIXED_PRECISION:
    scaler = GradScaler()

## 9. Training Functions

In [None]:
def train_epoch(model, loader, criterion, optimizer, scaler, device, use_amp):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable of ``(lr, hr)`` tensor batches.
        criterion: Loss called as ``criterion(output, hr)``.
        optimizer: Optimizer stepped once per batch.
        scaler: ``GradScaler`` used for loss scaling under mixed precision, or None.
        device: Device (or device string) the batches are moved to.
        use_amp: Request mixed precision. It is used only when ``scaler`` is
            given and ``device`` is CUDA; otherwise the epoch runs in fp32.

    Returns:
        Mean of the per-batch losses, or NaN if ``loader`` yielded no batches.
    """
    model.train()
    device = torch.device(device)
    # The file-level ``autocast`` is torch.cuda.amp.autocast, which does not
    # accept ``device_type`` (TypeError); torch.autocast does. fp16 autocast +
    # GradScaler is a CUDA recipe, so other devices train in fp32.
    amp_enabled = bool(use_amp) and scaler is not None and device.type == 'cuda'

    total_loss = 0.0
    num_batches = 0

    pbar = tqdm(loader, desc='训练')
    for lr, hr in pbar:
        lr, hr = lr.to(device), hr.to(device)
        optimizer.zero_grad()

        if amp_enabled:
            with torch.autocast(device_type='cuda'):
                out = model(lr)
                loss = criterion(out, hr)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model(lr)
            loss = criterion(out, hr)
            loss.backward()
            optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    # Avoid ZeroDivisionError on an empty loader.
    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches


def validate(model, loader, criterion, device, use_amp):
    """Return the mean per-batch loss of ``model`` over ``loader`` (no gradients).

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(lr, hr)`` tensor batches.
        criterion: Loss called as ``criterion(output, hr)``.
        device: Device (or device string) the batches are moved to.
        use_amp: Request fp16 autocast. It is applied only when ``device`` is
            CUDA; otherwise evaluation runs in fp32.

    Returns:
        Mean of the per-batch losses, or NaN if ``loader`` yielded no batches
        (NaN never compares lower than a previous best loss).
    """
    model.eval()
    device = torch.device(device)
    # torch.cuda.amp.autocast (the file-level import) rejects ``device_type``;
    # torch.autocast is the device-generic API.
    amp_enabled = bool(use_amp) and device.type == 'cuda'

    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for lr, hr in loader:
            lr, hr = lr.to(device), hr.to(device)

            if amp_enabled:
                with torch.autocast(device_type='cuda'):
                    out = model(lr)
                    loss = criterion(out, hr)
            else:
                out = model(lr)
                loss = criterion(out, hr)

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return float('nan')
    return total_loss / num_batches

## 10. Training Loop

In [None]:
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')

print("\n开始训练...\n")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 50)

    train_loss = train_epoch(model, train_loader, criterion, optimizer, scaler, device, USE_MIXED_PRECISION)
    val_loss = validate(model, val_loader, criterion, device, USE_MIXED_PRECISION)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    # The cosine schedule advances once per epoch.
    scheduler.step()

    print(f"\n训练损失: {train_loss:.6f}")
    print(f"验证损失: {val_loss:.6f}")
    print(f"学习率: {scheduler.get_last_lr()[0]:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_loss': val_loss,
            # base_ch is needed to rebuild UNet8x with matching weight shapes
            # from the checkpoint alone.
            'config': {
                'lr_patch': LR_PATCH_SIZE,
                'hr_patch': HR_PATCH_SIZE,
                'scale': SCALE_FACTOR,
                'base_ch': BASE_CHANNELS,
            },
        }, CHECKPOINT_DIR / 'best_model_8x.pth')
        print("✓ 保存最佳模型")

    # Periodic weights-only snapshot every 20 epochs.
    if (epoch + 1) % 20 == 0:
        torch.save(model.state_dict(), CHECKPOINT_DIR / f'epoch_{epoch+1}.pth')

print(f"\n训练完成！最佳验证损失: {best_val_loss:.6f}")

## 11. Plot Curve

In [None]:
# Per-epoch training and validation loss curves.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(history['train_loss'], label='Training', marker='o')
ax.plot(history['val_loss'], label='Validation', marker='s')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('8× Training Curve (64→512)')
ax.legend()
ax.grid(True)
fig.savefig(CHECKPOINT_DIR / 'curve.png', dpi=150)
plt.show()

## 12. Inference Methods

### 方法A: 8× + 2× Bicubic (推荐)

In [None]:
def inference_256_to_4096_methodA(model, lr_256, patch_size=64, device='cuda'):
    """Method A: tiled 8x model followed by a 2x bicubic upscale.

    256x256 -> (8x model, tile by tile) -> 2048x2048 -> (bicubic 2x) -> 4096x4096.
    Those sizes are for a 256x256 input; in general an ``H x W`` input yields
    a ``16H x 16W`` output.

    Args:
        model: 8x SR network mapping a ``patch_size`` tile to an 8x larger tile.
        lr_256: Image tensor of shape (1, C, H, W); callers in this notebook pass
            values scaled to [0, 1]. H and W must be multiples of ``patch_size``.
        patch_size: Side length of the non-overlapping LR tiles fed to the model.
        device: Device to run the model on.

    Returns:
        float32 tensor of shape (1, C, 16H, 16W) on ``device`` (not clipped).

    Raises:
        ValueError: if ``lr_256`` is not a single image or its size is not a
            multiple of ``patch_size``.
    """
    import contextlib

    scale = 8  # magnification of the tiled model stage
    device = torch.device(device)
    model.eval()

    if lr_256.dim() != 4 or lr_256.shape[0] != 1:
        raise ValueError(f'expected a (1, C, H, W) tensor, got shape {tuple(lr_256.shape)}')
    _, channels, height, width = lr_256.shape
    if height % patch_size or width % patch_size:
        raise ValueError(
            f'input size {height}x{width} is not a multiple of patch_size={patch_size}'
        )

    # Step 1: 8x via non-overlapping tiles.
    out_tile = patch_size * scale
    sr_8x = torch.zeros(1, channels, height * scale, width * scale, device=device)

    # The file-level ``autocast`` (torch.cuda.amp) rejects ``device_type``; use
    # torch.autocast, and only on CUDA.
    if device.type == 'cuda':
        amp_ctx = torch.autocast(device_type='cuda')
    else:
        amp_ctx = contextlib.nullcontext()

    with torch.no_grad(), amp_ctx:
        for y in range(0, height, patch_size):
            for x in range(0, width, patch_size):
                lr_tile = lr_256[:, :, y:y + patch_size, x:x + patch_size].to(device)
                oy, ox = y * scale, x * scale
                # The fp32 buffer up-casts fp16 autocast outputs on assignment.
                sr_8x[:, :, oy:oy + out_tile, ox:ox + out_tile] = model(lr_tile)

    # Step 2: 2x bicubic with OpenCV, which expects H x W x C and dsize=(w, h).
    sr_8x_np = sr_8x.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    sr_16x_np = cv2.resize(
        sr_8x_np, (width * scale * 2, height * scale * 2), interpolation=cv2.INTER_CUBIC
    )
    return torch.from_numpy(sr_16x_np.transpose(2, 0, 1)).unsqueeze(0).to(device)

### 可选: 训练一个简单的2× SR模型

In [None]:
class Simple2xSR(nn.Module):
    """Lightweight optional 2x super-resolution network.

    Three 3x3 conv + ReLU layers extract 64-channel features at input
    resolution; a nearest-neighbour 2x upsample followed by two 3x3 convs
    (64 -> 32 -> n_channels, ReLU in between) produces the output, which is
    not clamped.
    """

    def __init__(self, n_channels=3):
        super().__init__()
        # Feature extractor n_channels -> 64 -> 64 -> 64; convs sit at
        # Sequential indices 0/2/4 (state_dict keys features.0/2/4.*).
        feature_layers = []
        in_ch = n_channels
        for _ in range(3):
            feature_layers.append(nn.Conv2d(in_ch, 64, 3, padding=1))
            feature_layers.append(nn.ReLU(inplace=True))
            in_ch = 64
        self.features = nn.Sequential(*feature_layers)

        # 2x upsampling head.
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, n_channels, 3, padding=1),
        )

    def forward(self, x):
        return self.upsample(self.features(x))

# 如果想训练2×模型，取消注释:
# model_2x = Simple2xSR().to(device)
# 然后训练...

## 13. Test Inference

In [None]:
# Load the best checkpoint. map_location lets a checkpoint saved during a GPU
# run be loaded on a CPU-only machine as well.
checkpoint = torch.load(CHECKPOINT_DIR / 'best_model_8x.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Read the test image. sorted() matches the dataset's file ordering, so the
# chosen image does not depend on filesystem glob order.
test_hr_path = sorted(Path(HR_DIR).glob('*.png'))[0]
test_hr_4k = cv2.imread(str(test_hr_path))
if test_hr_4k is None:
    raise OSError(f'Failed to read image: {test_hr_path}')
test_hr_4k = cv2.cvtColor(test_hr_4k, cv2.COLOR_BGR2RGB)

test_lr_256 = cv2.resize(test_hr_4k, (256, 256), interpolation=cv2.INTER_CUBIC)
lr_tensor = torch.from_numpy(test_lr_256.transpose(2, 0, 1)).float() / 255.0
lr_tensor = lr_tensor.unsqueeze(0)

print(f"输入: {lr_tensor.shape}")
print("推理中 (8× + 2× bicubic)...")

sr_tensor = inference_256_to_4096_methodA(model, lr_tensor, LR_PATCH_SIZE, device)

print(f"输出: {sr_tensor.shape}")

# Visualize; the SR output and ground truth are area-downsized to 512x512 for display.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(test_lr_256)
axes[0].set_title('Input 256×256')
axes[0].axis('off')

sr_np = sr_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
sr_np = np.clip(sr_np, 0, 1)
sr_display = cv2.resize(sr_np, (512, 512), interpolation=cv2.INTER_AREA)
axes[1].imshow(sr_display)
axes[1].set_title(f'Output {sr_tensor.shape[2]}×{sr_tensor.shape[3]}')
axes[1].axis('off')

# Title with the real ground-truth size (H×W, same order as the output title).
hr_h, hr_w = test_hr_4k.shape[:2]
hr_display = cv2.resize(test_hr_4k / 255.0, (512, 512), interpolation=cv2.INTER_AREA)
axes[2].imshow(hr_display)
axes[2].set_title(f'GT {hr_h}×{hr_w}')
axes[2].axis('off')

plt.tight_layout()
plt.savefig(CHECKPOINT_DIR / 'result.png', dpi=150)
plt.show()

## 14. Evaluation

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def evaluate(model, hr_dir, num_samples=5):
    """Report mean PSNR/SSIM of the 8x + 2x-bicubic pipeline on a few HR images.

    For each of the first ``num_samples`` PNGs in ``hr_dir`` (sorted by path),
    the image is bicubically downscaled to 256x256, super-resolved with
    ``inference_256_to_4096_methodA`` and compared with the original. Both
    the SR result and the ground truth are area-downsampled to 1024x1024
    before scoring, so the metrics are measured at 1024x1024, not full size.

    Args:
        model: 8x SR network; inference runs on the device of its parameters.
        hr_dir: Directory of ground-truth PNG images.
        num_samples: Maximum number of images to evaluate.

    Returns:
        ``(mean_psnr, mean_ssim)``; both are NaN if no image was found.

    Raises:
        OSError: if OpenCV cannot decode one of the images.
    """
    # Use the model's own device instead of the notebook-global ``device``.
    eval_device = next(model.parameters()).device
    hr_images = sorted(Path(hr_dir).glob('*.png'))[:num_samples]
    psnr_scores, ssim_scores = [], []

    for hr_path in tqdm(hr_images, desc='评估'):
        hr_4k = cv2.imread(str(hr_path))
        if hr_4k is None:
            raise OSError(f'Failed to read image: {hr_path}')
        hr_4k = cv2.cvtColor(hr_4k, cv2.COLOR_BGR2RGB)

        lr_256 = cv2.resize(hr_4k, (256, 256), interpolation=cv2.INTER_CUBIC)
        lr_tensor = torch.from_numpy(lr_256.transpose(2, 0, 1)).float() / 255.0
        lr_tensor = lr_tensor.unsqueeze(0)

        sr_tensor = inference_256_to_4096_methodA(model, lr_tensor, LR_PATCH_SIZE, eval_device)

        sr_np = sr_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        sr_np = np.clip(sr_np, 0, 1)
        hr_np = hr_4k / 255.0

        # Score at 1024x1024 (area downsampling of both images).
        sr_small = cv2.resize(sr_np, (1024, 1024), interpolation=cv2.INTER_AREA)
        hr_small = cv2.resize(hr_np, (1024, 1024), interpolation=cv2.INTER_AREA)

        psnr_scores.append(psnr(hr_small, sr_small, data_range=1.0))
        ssim_scores.append(ssim(hr_small, sr_small, data_range=1.0, channel_axis=2))

    print(f"\n=== 评估结果 ===")
    if not psnr_scores:
        # Avoid numpy's mean-of-empty warnings; report NaN explicitly.
        print(f"No PNG images found in {hr_dir}")
        return float('nan'), float('nan')

    mean_psnr = np.mean(psnr_scores)
    mean_ssim = np.mean(ssim_scores)
    print(f"PSNR: {mean_psnr:.2f} dB")
    print(f"SSIM: {mean_ssim:.4f}")
    return mean_psnr, mean_ssim


avg_psnr, avg_ssim = evaluate(model, HR_DIR, 5)