# U-Net Based 4K Image Super-Resolution Training

This notebook implements a U-Net architecture for 4K image super-resolution tasks. The model learns to transform 256×256 low-resolution images into 4096×4096 high-resolution outputs.

## Project Overview
- **Task**: 4K Image Super-Resolution (SR)
- **Architecture**: Deep U-Net with increased depth and channel numbers
- **Input**: Low-resolution images (256×256)
- **Output**: High-resolution images (4096×4096)
- **Scale Factor**: 16×
- **Loss Function**: L1 Loss (Mean Absolute Error)

---

## 1. Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from tqdm import tqdm

## 2. Device Configuration

Check if CUDA GPU is available for accelerated training. The model will automatically use GPU if available, otherwise fall back to CPU.

In [None]:
# Check for a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

## 3. Dataset Class

### SRDataset4K: 4K Super-Resolution Dataset

This custom PyTorch Dataset class handles loading paired high-resolution (HR) and low-resolution (LR) images.

**Key Features:**
- Loads paired HR-LR images from separate directories
- HR images are at 4096×4096 resolution
- LR images are at 256×256 resolution
- Applies transformations (e.g., converting to tensors)
- Ensures HR and LR image counts match

**Process Flow:**
1. Load LR image (256×256)
2. Load corresponding HR image (4096×4096)
3. Convert both to tensors
4. Return (LR, HR) pair
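
Because pairs are formed purely by sorted order, a filename mismatch would silently pair the wrong images even when the counts agree. The following is a minimal sketch of a pre-flight check. It assumes that an LR file and its HR counterpart share the same filename stem, which is an assumption about the dataset layout rather than something the notebook states, and `find_unpaired` is a hypothetical helper name.

```python
from pathlib import Path
import tempfile


def find_unpaired(hr_dir, lr_dir, pattern="*.png"):
    """Return (stems only in HR, stems only in LR), each sorted.

    Assumes an LR file and its HR counterpart share the same filename stem.
    """
    hr_stems = {p.stem for p in Path(hr_dir).glob(pattern)}
    lr_stems = {p.stem for p in Path(lr_dir).glob(pattern)}
    return sorted(hr_stems - lr_stems), sorted(lr_stems - hr_stems)


def demo_find_unpaired():
    """Build two throwaway directories with one mismatched file each."""
    with tempfile.TemporaryDirectory() as tmp:
        hr, lr = Path(tmp, "hr"), Path(tmp, "lr")
        hr.mkdir()
        lr.mkdir()
        for name in ("0001.png", "0002.png"):
            (hr / name).touch()
        for name in ("0001.png", "0003.png"):
            (lr / name).touch()
        return find_unpaired(hr, lr)


print(demo_find_unpaired())  # (['0002'], ['0003'])
```

Running such a check on `HR_DIR` and `LR_DIR` before training should return two empty lists if every image has a partner.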

In [None]:
class SRDataset4K(Dataset):
    """4K super-resolution dataset of paired LR/HR images."""
    def __init__(self, hr_dir, lr_dir, transform=None):
        self.hr_dir = Path(hr_dir)
        self.lr_dir = Path(lr_dir)
        self.transform = transform
        
        # Collect all image files; LR/HR pairs are matched by sorted order
        self.hr_images = sorted(list(self.hr_dir.glob('*.png')))
        self.lr_images = sorted(list(self.lr_dir.glob('*.png')))
        
        assert len(self.hr_images) == len(self.lr_images), "HR and LR image counts do not match"
        
    def __len__(self):
        return len(self.hr_images)
    
    def __getitem__(self, idx):
        # Load the image pair
        hr_img = Image.open(self.hr_images[idx]).convert('RGB')
        lr_img = Image.open(self.lr_images[idx]).convert('RGB')
        
        if self.transform:
            hr_img = self.transform(hr_img)
            lr_img = self.transform(lr_img)
        
        return lr_img, hr_img

## 4. U-Net Model Architecture

### Overview of U-Net for 4K Super-Resolution

This implementation uses a **deep U-Net** architecture designed specifically for 16× super-resolution (256×256 → 4096×4096).

---

### 4.1 DoubleConv Block

The basic building block of our U-Net. Each DoubleConv consists of:
- **Conv2d** (3×3 kernel, padding=1) → preserves spatial dimensions
- **BatchNorm2d** → stabilizes training for larger models
- **ReLU** activation (inplace for memory efficiency)
- **Conv2d** (3×3 kernel, padding=1)
- **BatchNorm2d**
- **ReLU** activation

---

### 4.2 Down Block (Encoder)

Downsampling block that reduces spatial dimensions while increasing feature channels:
1. **MaxPool2d** (2×2) → reduces spatial size by half
2. **DoubleConv** → extracts features at this resolution

**Purpose**: Extract hierarchical features, moving from fine, high-resolution detail to coarse, more abstract representations.

---

### 4.3 Up Block (Decoder)

Upsampling block that increases spatial dimensions:
1. **ConvTranspose2d** (2×2, stride=2) → doubles spatial size
2. **Concatenate** with skip connection from encoder
3. **DoubleConv** → refines the concatenated features

**Skip Connections**: Combines low-level and high-level features for better detail recovery.
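
The shape bookkeeping in sections 4.1 to 4.3 follows from the standard output-size formulas for convolution and transposed convolution (here with dilation 1 and no output padding). The sketch below is plain Python. Its helper names are illustrative and are not part of the notebook.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution or pooling layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0):
    """Output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# 3x3 convolution with padding=1 keeps the spatial size
print(conv_out(256, kernel=3, padding=1))            # 256
# 2x2 max pooling with stride 2 halves it
print(conv_out(256, kernel=2, stride=2))             # 128
# 2x2 transposed convolution with stride 2 doubles it
print(conv_transpose_out(16, kernel=2, stride=2))    # 32

# Channel count entering DoubleConv in an Up block with a skip connection:
# the transposed conv halves the channels, then the skip is concatenated.
upsampled, skip = 1024 // 2, 512
print(upsampled + skip)                              # 1024
```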

---

### 4.4 Complete U-Net Architecture for 4K

#### Encoder Path (Downsampling):
```
Input: 3 channels (RGB), 256×256
    ↓ [DoubleConv]
  64 channels, 256×256
    ↓ [Down1: MaxPool + DoubleConv]
 128 channels, 128×128
    ↓ [Down2: MaxPool + DoubleConv]
 256 channels, 64×64
    ↓ [Down3: MaxPool + DoubleConv]
 512 channels, 32×32
    ↓ [Down4: MaxPool + DoubleConv]
1024 channels, 16×16 (Bottleneck)
```

#### Decoder Path (Upsampling with 16× scale):
```
1024 channels, 16×16
    ↓ [Up1: TransConv + Concat(512 from Down3) + DoubleConv]
 512 channels, 32×32
    ↓ [Up2: TransConv + Concat(256 from Down2) + DoubleConv]
 256 channels, 64×64
    ↓ [Up3: TransConv + Concat(128 from Down1) + DoubleConv]
 128 channels, 128×128
    ↓ [Up4: TransConv + Concat(64 from inc) + DoubleConv]
  64 channels, 256×256
    ↓ [Up5: TransConv + DoubleConv]  (Up5 to Up8: no skip connections, 2× each, 16× in total)
  64 channels, 512×512
    ↓ [Up6: TransConv + DoubleConv]
  64 channels, 1024×1024
    ↓ [Up7: TransConv + DoubleConv]
  64 channels, 2048×2048
    ↓ [Up8: TransConv + DoubleConv]
  64 channels, 4096×4096
    ↓ [Output Conv 1×1]
   3 channels, 4096×4096 (Output RGB image)
```
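
The layout above can be traced with a few lines of plain Python as a sanity check that four halvings followed by eight doublings land on 4096×4096. This sketch only mirrors the diagram and does not touch the model; `trace_unet4k` is an illustrative name.

```python
def trace_unet4k(size=256):
    """Return (stage, channels, spatial size) tuples for the layout above."""
    stages = [("inc", 64, size)]
    channels = 64
    # Encoder: each Down block halves the size and doubles the channels
    for i in range(1, 5):
        size //= 2
        channels *= 2
        stages.append((f"down{i}", channels, size))
    # Decoder with skips: each Up block doubles the size and halves the channels
    for i in range(1, 5):
        size *= 2
        channels //= 2
        stages.append((f"up{i}", channels, size))
    # Extra upsampling without skips: size doubles, channels stay at 64
    for i in range(5, 9):
        size *= 2
        stages.append((f"up{i}", channels, size))
    stages.append(("outc", 3, size))
    return stages


for name, channels, size in trace_unet4k():
    print(f"{name:>6}: {channels:>4} channels, {size}x{size}")
```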

---

### Model Optimizations for 4K:
1. **Deep architecture**: 4 downsampling + 8 upsampling layers for 16× scale
2. **BatchNorm**: Added for training stability with high resolution
3. **Transposed convolutions**: Learnable upsampling for better quality
4. **Progressive upsampling**: Multiple stages to reach 4096×4096

**Total Parameters**: roughly 31M for the channel widths shown above (64 to 1024); the exact count is printed when the model is initialized in section 7.
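
The parameter total can be cross-checked by hand. The tally below is a sketch that assumes the channel widths in the diagram, a bias on every convolution, and two learnable parameters per BatchNorm channel (the PyTorch defaults). All function names are illustrative.

```python
def conv_params(c_in, c_out, kernel):
    """Weights plus biases of a Conv2d or ConvTranspose2d layer."""
    return c_in * c_out * kernel * kernel + c_out


def double_conv_params(c_in, c_out):
    """Two 3x3 convolutions, each followed by BatchNorm (weight + bias)."""
    return (conv_params(c_in, c_out, 3) + 2 * c_out
            + conv_params(c_out, c_out, 3) + 2 * c_out)


def up_params(c_in, c_out, skip=True):
    """2x2 transposed conv that halves channels, then a DoubleConv."""
    # With a skip of c_in // 2 channels, the concatenation has c_in channels.
    conv_in = c_in if skip else c_in // 2
    return conv_params(c_in, c_in // 2, 2) + double_conv_params(conv_in, c_out)


def unet4k_params():
    total = double_conv_params(3, 64)                      # inc
    for c in (64, 128, 256, 512):                          # down1..down4
        total += double_conv_params(c, 2 * c)
    for c in (1024, 512, 256, 128):                        # up1..up4 (with skips)
        total += up_params(c, c // 2)
    total += 4 * up_params(64, 64, skip=False)             # up5..up8
    total += conv_params(64, 3, 1)                         # 1x1 output conv
    return total


print(f"{unet4k_params():,}")  # 31,299,267
```

Under these assumptions the bottleneck block (`down4`) and the first decoder block (`up1`) account for most of the total, while the four extra upsampling blocks add only about 0.26M parameters.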

In [None]:
class DoubleConv(nn.Module):
    """Double convolution block: (Conv -> BN -> ReLU) * 2"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downsampling block: MaxPool -> DoubleConv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
    
    def forward(self, x):
        return self.maxpool_conv(x)


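class Up(nn.Module):
    """Upsampling block: ConvTranspose -> (optional concat) -> DoubleConv

    in_channels is the channel count of the tensor being upsampled. With a skip
    connection, the skip tensor is expected to carry in_channels // 2 channels.
    """
    def __init__(self, in_channels, out_channels, skip_connection=True):
        super().__init__()
        self.skip_connection = skip_connection
        # Upsample with a transposed convolution (halves channels, doubles H and W)
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)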
        
        if skip_connection:
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.conv = DoubleConv(in_channels // 2, out_channels)
    
    def forward(self, x1, x2=None):
        x1 = self.up(x1)
        # Concatenate the skip connection
        if self.skip_connection and x2 is not None:
            x = torch.cat([x2, x1], dim=1)
        else:
            x = x1
        return self.conv(x)


class UNet4K(nn.Module):
    """Deep U-Net for 4K super-resolution (256x256 -> 4096x4096, 16x upscaling)"""
    def __init__(self, n_channels=3, n_classes=3):
        super(UNet4K, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        
        # Encoder (256 -> 16)
        self.inc = DoubleConv(n_channels, 64)        # 256x256
        self.down1 = Down(64, 128)                   # 128x128
        self.down2 = Down(128, 256)                  # 64x64
        self.down3 = Down(256, 512)                  # 32x32
        self.down4 = Down(512, 1024)                 # 16x16 (bottleneck)
        
        # Decoder (16 -> 256, with skip connections).
        # Up(in_channels, ...) takes the channel count of the tensor being upsampled:
        # the transposed conv halves it, and concatenating the skip restores it.
        # Passing 1024 + 512 here would make the transposed conv expect 1536
        # input channels and fail on the 1024-channel bottleneck output.
        self.up1 = Up(1024, 512)                     # 32x32
        self.up2 = Up(512, 256)                      # 64x64
        self.up3 = Up(256, 128)                      # 128x128
        self.up4 = Up(128, 64)                       # 256x256
        
        # Extra upsampling layers without skips (256 -> 4096, 16x upsampling)
        self.up5 = Up(64, 64, skip_connection=False)   # 512x512
        self.up6 = Up(64, 64, skip_connection=False)   # 1024x1024
        self.up7 = Up(64, 64, skip_connection=False)   # 2048x2048
        self.up8 = Up(64, 64, skip_connection=False)   # 4096x4096
        
        # Output layer
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
    
    def forward(self, x):
        # Encoder path
        x1 = self.inc(x)      # 64, 256x256
        x2 = self.down1(x1)   # 128, 128x128
        x3 = self.down2(x2)   # 256, 64x64
        x4 = self.down3(x3)   # 512, 32x32
        x5 = self.down4(x4)   # 1024, 16x16
        
        # Decoder path (with skip connections)
        x = self.up1(x5, x4)  # 512, 32x32
        x = self.up2(x, x3)   # 256, 64x64
        x = self.up3(x, x2)   # 128, 128x128
        x = self.up4(x, x1)   # 64, 256x256
        
        # Extra upsampling to 4K
        x = self.up5(x)       # 64, 512x512
        x = self.up6(x)       # 64, 1024x1024
        x = self.up7(x)       # 64, 2048x2048
        x = self.up8(x)       # 64, 4096x4096
        
        # Output
        return self.outc(x)   # 3, 4096x4096

## 5. Training Configuration

### Hyperparameters

These parameters are tuned for **4K super-resolution training**:

- **BATCH_SIZE = 1**: Due to 4K resolution, use batch size of 1 to fit in GPU memory
- **LEARNING_RATE = 1e-4**: Lower learning rate for stable training with large images
- **NUM_EPOCHS = 50**: More epochs needed for complex 16× SR task
- **TRAIN_SPLIT = 0.9**: 90% training data, 10% validation data
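
To see why a batch size of 1 is the practical ceiling, it helps to estimate the size of a single activation tensor in the last decoder stages. The sketch below assumes float32 activations. Autograd keeps several such tensors per block for the backward pass, so real usage is a multiple of this figure. `feature_map_bytes` is an illustrative helper.

```python
def feature_map_bytes(channels, height, width, batch=1, bytes_per_value=4):
    """Memory for one activation tensor (float32 by default)."""
    return batch * channels * height * width * bytes_per_value


GIB = 1024 ** 3
# One 64-channel feature map at full 4096x4096 resolution
print(feature_map_bytes(64, 4096, 4096) / GIB)   # 4.0
# The same feature map at the 256x256 input resolution
print(feature_map_bytes(64, 256, 256) / GIB)     # 0.015625
```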

### Data Paths
- `HR_DIR`: Directory containing 4096×4096 high-resolution ground truth images
- `LR_DIR`: Directory containing 256×256 low-resolution input images
- `CHECKPOINT_DIR`: Where to save model checkpoints and training results

In [None]:
# Hyperparameters for 4K training
BATCH_SIZE = 1  # 4K images need a lot of GPU memory, so use a batch size of 1
LEARNING_RATE = 1e-4  # Relatively low learning rate
NUM_EPOCHS = 50  # More training epochs
TRAIN_SPLIT = 0.9  # 90% training, 10% validation

# Data paths
HR_DIR = './dataset_4k/high_resolution'
LR_DIR = './dataset_4k/low_resolution'
CHECKPOINT_DIR = Path('./checkpoints_4k')
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Epochs: {NUM_EPOCHS}")
print("Optimization strategy: deep U-Net + BatchNorm + progressive upsampling")

## 6. Data Loading and Preprocessing

### Data Pipeline:
1. **Transform**: Convert PIL images to PyTorch tensors (values in [0, 1])
2. **Dataset**: Load full dataset with paired HR-LR images
3. **Split**: Randomly divide into training and validation sets
4. **DataLoader**: Create batch iterators for efficient GPU utilization

**Note**: `num_workers=0` is used to avoid multiprocessing issues. `pin_memory=True` speeds up data transfer to GPU.
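
The split step truncates the training size toward zero and gives the remainder to validation, which matters for very small datasets. The sketch below reproduces that arithmetic on its own; `split_sizes` is an illustrative helper.

```python
def split_sizes(n_samples, train_fraction=0.9):
    """Mirror the notebook's split: floor for training, remainder for validation."""
    train_size = int(train_fraction * n_samples)
    return train_size, n_samples - train_size


print(split_sizes(100))  # (90, 10)
print(split_sizes(7))    # (6, 1)
print(split_sizes(1))    # (0, 1)
```

With a single sample the training set is empty, so the dataset needs at least two images for this split to leave anything to train on.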

In [None]:
# Data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the dataset
full_dataset = SRDataset4K(HR_DIR, LR_DIR, transform=transform)

# Split into training and validation sets
train_size = int(TRAIN_SPLIT * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                         num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                       num_workers=0, pin_memory=True)

print(f"Total samples: {len(full_dataset)}")
print(f"Training set: {train_size}")
print(f"Validation set: {val_size}")
print(f"Training batches: {len(train_loader)}")

## 7. Model Initialization

### Components:
- **Model**: UNet4K with 3 input channels (RGB) and 3 output channels (RGB)
- **Loss Function**: L1 Loss (MAE) - better preserves sharp edges than L2 loss
- **Optimizer**: Adam optimizer with learning rate 1e-4
- **LR Scheduler**: StepLR - halves the learning rate (gamma=0.5) every 15 epochs

The parameter count is printed to verify the deep architecture (roughly 31M parameters for the channel widths described in section 4.4).
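
The step schedule has a simple closed form: the learning rate during a given epoch is the base rate multiplied by `gamma` once for every completed block of `step_size` epochs. The sketch below is plain Python using this notebook's values as defaults. `step_lr` is an illustrative helper, not the PyTorch scheduler itself.

```python
def step_lr(epoch, base_lr=1e-4, step_size=15, gamma=0.5):
    """Learning rate in effect during `epoch` (0-based) under a StepLR-style decay."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 14, 15, 30, 45, 49):
    print(f"epoch {epoch + 1:>2}: lr = {step_lr(epoch):.3e}")
```

Over 50 epochs the rate therefore takes four values: 1e-4, 5e-5, 2.5e-5, and 1.25e-5 for the final five epochs.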

In [None]:
# Initialize the model
model = UNet4K(n_channels=3, n_classes=3).to(device)

# Loss function and optimizer
criterion = nn.L1Loss()  # L1 loss (MAE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

# Print model information
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal model parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 8. Training and Validation Functions

### train_epoch()
Trains the model for one complete epoch:
1. Set model to training mode
2. For each batch:
   - Forward pass: predict 4K SR images
   - Calculate L1 loss between predictions and ground truth
   - Backward pass: compute gradients
   - Update model weights
3. Return average loss for the epoch

### validate()
Evaluates model on validation set:
1. Set model to evaluation mode (BatchNorm layers use their running statistics instead of per-batch statistics)
2. Disable gradient computation (`torch.no_grad()`)
3. Calculate loss on validation batches
4. Return average validation loss

**Purpose**: Validation loss helps monitor overfitting and select the best model.

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    
    pbar = tqdm(train_loader, desc='Training')
    for lr_imgs, hr_imgs in pbar:
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(lr_imgs)
        loss = criterion(outputs, hr_imgs)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return running_loss / len(train_loader)


def validate(model, val_loader, criterion, device):
    """Evaluate on the validation set and return the mean batch loss."""
    model.eval()
    running_loss = 0.0
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in val_loader:
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)
            
            running_loss += loss.item()
    
    return running_loss / len(val_loader)

## 9. Training Loop

This cell executes the main training loop for all epochs.

### Process for Each Epoch:
1. **Train**: Run one epoch on training set
2. **Validate**: Evaluate on validation set
3. **Record**: Save losses to history
4. **Scheduler**: Update learning rate
5. **Checkpoint**:
   - Save best model when validation loss improves
   - Save checkpoint every 10 epochs for backup

### Monitoring:
- Training loss should decrease steadily
- Validation loss should track training loss (a widening gap indicates overfitting)
- Progress bars show real-time batch-level loss

The best model is saved based on lowest validation loss, ensuring the model generalizes well to unseen data.
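
The two checkpoint rules are independent of each other: the best-model file is overwritten only on a strict improvement, while the periodic backup fires on every tenth epoch regardless of the loss. The sketch below replays that logic on made-up validation losses, with no model or files involved. `checkpoint_events` is an illustrative name.

```python
def checkpoint_events(val_losses, save_every=10):
    """Return (epochs that set a new best, epochs with a periodic save), 1-based."""
    best = float("inf")
    best_epochs, periodic_epochs = [], []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:               # strict improvement overwrites best_model.pth
            best = loss
            best_epochs.append(epoch)
        if epoch % save_every == 0:   # periodic backup regardless of loss
            periodic_epochs.append(epoch)
    return best_epochs, periodic_epochs


# Made-up validation losses, for illustration only
losses = [0.30, 0.25, 0.27, 0.22, 0.22, 0.24, 0.21, 0.23, 0.26, 0.28]
print(checkpoint_events(losses))  # ([1, 2, 4, 7], [10])
```

In this example the epoch-10 backup holds a worse model than the best checkpoint from epoch 7, which is why the best model is the one loaded for evaluation.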

In [None]:
# Track training history
history = {
    'train_loss': [],
    'val_loss': []
}

best_val_loss = float('inf')

print("\nStarting training...\n")

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 50)
    
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss = validate(model, val_loader, criterion, device)
    
    # Record
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    
    # Update the learning rate
    scheduler.step()
    
    print(f"\nTraining loss: {train_loss:.6f}")
    print(f"Validation loss: {val_loss:.6f}")
    print(f"Current learning rate: {scheduler.get_last_lr()[0]:.6f}")
    
    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
        }, CHECKPOINT_DIR / 'best_model.pth')
        print(f"✓ Saved best model (validation loss: {val_loss:.6f})")
    
    # Save a periodic checkpoint
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_loss': val_loss,
        }, CHECKPOINT_DIR / f'checkpoint_epoch_{epoch+1}.pth')
        print(f"✓ Saved checkpoint: epoch_{epoch+1}")

print("\nTraining complete!")
print(f"Best validation loss: {best_val_loss:.6f}")

## 10. Training Curve Visualization

Plots the training and validation loss curves over all epochs.

**What to Look For:**
- **Decreasing trend**: Both losses should decrease over time
- **Convergence**: Losses should stabilize toward the end
- **Overfitting**: If validation loss increases while training loss decreases, the model is overfitting
- **Underfitting**: If both losses remain high, the model needs more capacity or training

The curve is saved as an image for documentation and analysis.

In [None]:
# Plot the loss curves
plt.figure(figsize=(10, 5))
plt.plot(history['train_loss'], label='training loss', marker='o')
plt.plot(history['val_loss'], label='validation loss', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Loss (L1)')
plt.title('4K Super-Resolution Training Loss Curve')
plt.legend()
plt.grid(True)
plt.savefig(CHECKPOINT_DIR / 'training_curve.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Training curve saved: {CHECKPOINT_DIR / 'training_curve.png'}")

## 11. Load Best Model

Load the checkpoint with the lowest validation loss for inference and evaluation.

The model is set to evaluation mode so that BatchNorm layers use their running statistics, which keeps inference behavior consistent.

In [None]:
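# Load the best model (map_location keeps this working on CPU-only machines)
checkpoint = torch.load(CHECKPOINT_DIR / 'best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded best model (epoch {checkpoint['epoch']+1}, validation loss: {checkpoint['val_loss']:.6f})")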

## 12. Visual Comparison of Results

Generates side-by-side comparisons of:
1. **Low Resolution Input**: The 256×256 LR image fed to the model
2. **Model Output**: The 4K (4096×4096) super-resolved image generated by our U-Net
3. **Ground Truth**: The actual 4096×4096 high-resolution image

**Note**: Images are downsampled for display purposes.

**Purpose**: Visually assess the model's performance in recovering fine details, textures, and overall image quality at 4K resolution.

In [None]:
import cv2

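# Visualize results on validation samples
num_samples = min(3, len(val_dataset))  # 4K images are large, so show only 3 samples
# squeeze=False keeps axes 2-D even when num_samples == 1
fig, axes = plt.subplots(3, num_samples, figsize=(15, 15), squeeze=False)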

with torch.no_grad():
    for i in range(num_samples):
        # Fetch one sample
        lr_img, hr_img = val_dataset[i]
        lr_img_input = lr_img.unsqueeze(0).to(device)
        
        # Generate the super-resolved image
        sr_img = model(lr_img_input)
        
        # Convert to NumPy arrays (H, W, C) for display
        lr_img_np = lr_img.cpu().numpy().transpose(1, 2, 0)
        hr_img_np = hr_img.cpu().numpy().transpose(1, 2, 0)
        sr_img_np = sr_img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        
        # Clip to [0, 1]
        lr_img_np = np.clip(lr_img_np, 0, 1)
        hr_img_np = np.clip(hr_img_np, 0, 1)
        sr_img_np = np.clip(sr_img_np, 0, 1)
        
        # Downsample the 4K images for display
        display_size = 512
        hr_img_display = cv2.resize(hr_img_np, (display_size, display_size), interpolation=cv2.INTER_AREA)
        sr_img_display = cv2.resize(sr_img_np, (display_size, display_size), interpolation=cv2.INTER_AREA)
        
        # Show the low-resolution input
        axes[0, i].imshow(lr_img_np)
        axes[0, i].set_title('LR input (256x256)')
        axes[0, i].axis('off')
        
        # Show the generated high-resolution output
        axes[1, i].imshow(sr_img_display)
        axes[1, i].set_title('Model output (4096x4096)')
        axes[1, i].axis('off')
        
        # Show the ground-truth high-resolution image
        axes[2, i].imshow(hr_img_display)
        axes[2, i].set_title('Ground truth (4096x4096)')
        axes[2, i].axis('off')

plt.tight_layout()
plt.savefig(CHECKPOINT_DIR / 'test_results.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Test results saved: {CHECKPOINT_DIR / 'test_results.png'}")

## 13. Quantitative Evaluation Metrics

### PSNR (Peak Signal-to-Noise Ratio)
- Measures pixel-wise difference between images
- **Higher is better** (typically 25-35 dB for SR tasks)
- PSNR = 20 × log10(MAX / √MSE)
- Sensitive to pixel-level accuracy but may not correlate perfectly with perceptual quality

### SSIM (Structural Similarity Index)
- Measures structural similarity considering luminance, contrast, and structure
- **Range**: typically 0 to 1 for natural images (1 = identical images); slightly negative values are possible in principle
- Better correlates with human perception than PSNR
- Values above 0.95 indicate excellent quality

**Note**: For 4K images, metrics are calculated on downsampled versions to save memory.
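
As a worked instance of the PSNR formula above, the sketch below computes it from scratch in plain Python for pixel values on a [0, 1] scale. It is meant only to make the arithmetic concrete; the notebook itself relies on scikit-image, and the helper names here are illustrative.

```python
import math


def mse(a, b):
    """Mean squared error between two equal-length sequences of pixel values."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr_from_mse(mse_value, max_value=1.0):
    """PSNR in dB, using PSNR = 20 * log10(MAX / sqrt(MSE))."""
    if mse_value == 0:
        return float("inf")  # identical images
    return 20 * math.log10(max_value / math.sqrt(mse_value))


# Every pixel off by 0.1 on a [0, 1] scale gives MSE = 0.01, i.e. about 20 dB
reference = [0.2, 0.5, 0.8, 0.4]
estimate = [0.3, 0.6, 0.9, 0.5]
print(round(psnr_from_mse(mse(reference, estimate)), 2))  # 20.0
```

Each tenfold reduction in MSE adds 10 dB, so an MSE of 0.001 corresponds to 30 dB.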

In [None]:
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

def calculate_metrics_4k(model, val_loader, device):
    """Compute PSNR and SSIM (4K version)."""
    model.eval()
    psnr_scores = []
    ssim_scores = []
    
    with torch.no_grad():
        for lr_imgs, hr_imgs in tqdm(val_loader, desc='Computing metrics'):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)
            
            sr_imgs = model(lr_imgs)
            
            # Convert to NumPy arrays (N, H, W, C)
            sr_imgs_np = sr_imgs.cpu().numpy().transpose(0, 2, 3, 1)
            hr_imgs_np = hr_imgs.cpu().numpy().transpose(0, 2, 3, 1)
            
            # Compute metrics for each image
            for sr_img, hr_img in zip(sr_imgs_np, hr_imgs_np):
                sr_img = np.clip(sr_img, 0, 1)
                hr_img = np.clip(hr_img, 0, 1)
                
                # To save memory, compute metrics on downsampled images
                sr_img_small = cv2.resize(sr_img, (1024, 1024), interpolation=cv2.INTER_AREA)
                hr_img_small = cv2.resize(hr_img, (1024, 1024), interpolation=cv2.INTER_AREA)
                
                psnr_score = psnr(hr_img_small, sr_img_small, data_range=1.0)
                ssim_score = ssim(hr_img_small, sr_img_small, data_range=1.0, channel_axis=2)
                
                psnr_scores.append(psnr_score)
                ssim_scores.append(ssim_score)
    
    return np.mean(psnr_scores), np.mean(ssim_scores)

# Compute the metrics
avg_psnr, avg_ssim = calculate_metrics_4k(model, val_loader, device)

print("\n=== Evaluation metrics ===")
print(f"Average PSNR: {avg_psnr:.2f} dB")
print(f"Average SSIM: {avg_ssim:.4f}")

## 14. Training Summary

Generates a comprehensive summary of the training session including:
- Model architecture details
- Hyperparameters used
- Final performance metrics (loss, PSNR, SSIM)
- File paths to saved models and results

This summary is saved as a text file for future reference and experiment tracking.

In [None]:
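# Save the training summary
summary = f"""
4K Super-Resolution Training Summary
{'='*50}
Model: UNet4K (deep U-Net)
Input resolution: 256×256
Output resolution: 4096×4096
Scale factor: 16×

Training configuration:
- Epochs: {NUM_EPOCHS}
- Batch size: {BATCH_SIZE}
- Learning rate: {LEARNING_RATE}
- Loss function: L1 Loss
- Optimizer: Adam
- LR schedule: StepLR (step=15, gamma=0.5)

Final results:
- Best validation loss: {best_val_loss:.6f}
- Average PSNR: {avg_psnr:.2f} dB
- Average SSIM: {avg_ssim:.4f}

Model saved to: {CHECKPOINT_DIR / 'best_model.pth'}
"""

print(summary)

# Write as UTF-8 so the '×' characters are saved correctly on every platform
with open(CHECKPOINT_DIR / 'training_summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary)

print(f"Training summary saved: {CHECKPOINT_DIR / 'training_summary.txt'}")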