# Patch-based Training with Memory Debugging

## 特性：
- ✅ 修复autocast兼容性问题
- ✅ 详细的显存监控
- ✅ 每步显示显存使用
- ✅ 训练前自动探测可用的最大batch size（遇到OOM即停止探测）
- ✅ 梯度累积备选方案

---

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
import random
from tqdm import tqdm
import gc

## 1. Memory Monitoring Utilities

In [None]:
def print_gpu_memory(tag=""):
    """Print and return the current CUDA memory statistics in GiB.

    Args:
        tag: Label shown in square brackets at the start of the printed line.

    Returns:
        A tuple ``(allocated, reserved, peak_allocated)`` in GiB. When CUDA
        is not available nothing is printed and ``(0, 0, 0)`` is returned.
    """
    if not torch.cuda.is_available():
        return 0, 0, 0

    gib = 1024**3
    stats = tuple(
        query() / gib
        for query in (
            torch.cuda.memory_allocated,
            torch.cuda.memory_reserved,
            torch.cuda.max_memory_allocated,
        )
    )
    current, cached, peak = stats
    print(f"[{tag}] GPU显存: 已分配={current:.2f}GB, 保留={cached:.2f}GB, 峰值={peak:.2f}GB")
    return stats

def clear_gpu_memory():
    """Run the Python garbage collector, then release cached CUDA memory.

    The CUDA steps (emptying the caching allocator and synchronizing the
    device) are skipped when CUDA is not available.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

def get_tensor_memory(tensor):
    """Return the size of ``tensor``'s elements in MiB.

    The value is ``number of elements * bytes per element``; it does not
    depend on which device the tensor lives on.
    """
    bytes_per_mib = 1024**2
    total_bytes = tensor.nelement() * tensor.element_size()
    return total_bytes / bytes_per_mib

## 2. Device Configuration

In [None]:
# Prefer CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'使用设备: {device}')

if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    # Total memory of GPU 0, converted from bytes to GiB.
    total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f'总显存: {total_memory:.2f} GB')

    # Start from clean statistics so later peak readings refer to this run.
    torch.cuda.reset_peak_memory_stats()
    clear_gpu_memory()
    print_gpu_memory("初始状态")

## 3. Dataset

In [None]:
class PatchSRDataset(Dataset):
    """Dataset of random HR patches paired with bicubically downscaled LR patches.

    Every ``*.png`` file directly inside ``hr_dir`` is treated as a
    high-resolution image, and each image contributes ``patches_per_image``
    dataset entries. The crop position (and the optional augmentation) is
    drawn anew on every access, so the same index yields different patches
    over time.
    """

    def __init__(self, hr_dir, lr_patch_size=64, scale_factor=8,
                 patches_per_image=10, augment=True):
        """Index the PNG files in ``hr_dir`` and print the configuration.

        Args:
            hr_dir: Directory containing the high-resolution ``*.png`` images.
            lr_patch_size: Side length of the square low-resolution patch.
            scale_factor: Upscaling factor; the HR patch side is
                ``lr_patch_size * scale_factor``.
            patches_per_image: Number of dataset entries per image.
            augment: Whether to apply random flips and 90-degree rotations.
        """
        self.hr_dir = Path(hr_dir)
        self.lr_patch_size = lr_patch_size
        self.hr_patch_size = lr_patch_size * scale_factor
        self.scale_factor = scale_factor
        self.patches_per_image = patches_per_image
        self.augment = augment

        # Sorted so that index -> image mapping is stable across runs.
        self.hr_images = sorted(self.hr_dir.glob('*.png'))

        print("\n数据集配置:")
        print(f"  图像数量: {len(self.hr_images)}")
        print(f"  LR patch: {lr_patch_size}×{lr_patch_size}")
        print(f"  HR patch: {self.hr_patch_size}×{self.hr_patch_size}")
        print(f"  缩放倍数: {scale_factor}×")
        print(f"  总patches: {len(self.hr_images) * patches_per_image}")

    def __len__(self):
        """Return the number of entries: images times patches per image."""
        return len(self.hr_images) * self.patches_per_image

    def augment_patch(self, lr, hr):
        """Apply the same random flips / 90-degree rotation to both patches.

        Args:
            lr: Low-resolution patch as an ``(H, W, C)`` array.
            hr: High-resolution patch as an ``(H, W, C)`` array.

        Returns:
            ``(lr, hr)`` as fresh contiguous copies (the NumPy flip/rotate
            calls return views, which ``torch.from_numpy`` cannot always wrap).
        """
        if random.random() > 0.5:
            lr, hr = np.fliplr(lr), np.fliplr(hr)
        if random.random() > 0.5:
            lr, hr = np.flipud(lr), np.flipud(hr)
        k = random.randint(0, 3)
        if k > 0:
            lr, hr = np.rot90(lr, k), np.rot90(hr, k)
        return lr.copy(), hr.copy()

    def _crop_hr_patch(self, hr_img):
        """Return a square HR patch of side ``hr_patch_size`` from ``hr_img``.

        A random crop is taken whenever the image is at least as large as the
        patch in both dimensions. Only an image that is smaller than the patch
        along some axis is resized to the patch size.
        """
        size = self.hr_patch_size
        h, w = hr_img.shape[:2]
        max_y, max_x = h - size, w - size

        # Strictly negative: an exact fit (offset range 0..0) is still a
        # valid crop and must not be resized, which would distort the
        # aspect ratio when only one side matches.
        if max_y < 0 or max_x < 0:
            return cv2.resize(hr_img, (size, size))

        y, x = random.randint(0, max_y), random.randint(0, max_x)
        return hr_img[y:y + size, x:x + size]

    def __getitem__(self, idx):
        """Load one image and return a random ``(lr_tensor, hr_tensor)`` pair.

        Args:
            idx: Dataset index; ``idx // patches_per_image`` selects the image.

        Returns:
            Two float tensors in ``(C, H, W)`` layout scaled to ``[0, 1]``.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        img_idx = idx // self.patches_per_image
        img_path = self.hr_images[img_idx]

        hr_img = cv2.imread(str(img_path))
        if hr_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Failed to read image: {img_path}")
        hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)

        hr_patch = self._crop_hr_patch(hr_img)

        lr_patch = cv2.resize(hr_patch, (self.lr_patch_size, self.lr_patch_size),
                              interpolation=cv2.INTER_CUBIC)

        if self.augment:
            lr_patch, hr_patch = self.augment_patch(lr_patch, hr_patch)

        # HWC uint8 -> CHW float in [0, 1].
        lr_tensor = torch.from_numpy(lr_patch.transpose(2, 0, 1)).float() / 255.0
        hr_tensor = torch.from_numpy(hr_patch.transpose(2, 0, 1)).float() / 255.0

        return lr_tensor, hr_tensor

## 4. Lightweight Model

In [None]:
class ResBlock(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3, added to the input, then ReLU.

    The channel count is preserved, and ``padding=1`` keeps the spatial size.
    """

    def __init__(self, ch):
        super().__init__()
        layers = [
            nn.Conv2d(ch, ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch, 3, padding=1),
        ]
        self.conv = nn.Sequential(*layers)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(x + conv(x))``."""
        # The in-place ReLU acts on the freshly created sum, not on ``x``.
        summed = x + self.conv(x)
        return self.relu(summed)


class TinyUNet8x(nn.Module):
    """Very small 8x super-resolution network (e.g. 64x64 -> 512x512).

    Two stride-2 convolution stages reduce the input to 1/4 resolution and
    five nearest-neighbour 2x upsampling stages then bring it to 8x the input
    size. The layers are applied strictly in sequence; there are no skip
    connections between the down and up paths. Designed to keep GPU memory
    usage low.
    """

    def __init__(self, base_ch=24):
        """Build the network; ``base_ch`` is the channel width of the first layer."""
        super().__init__()
        c1, c2, c4 = base_ch, base_ch * 2, base_ch * 4

        # Encoder: full resolution -> 1/2 -> 1/4 (64 -> 32 -> 16).
        self.inc = nn.Sequential(
            nn.Conv2d(3, c1, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(c1, c2, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResBlock(c2),
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(c2, c4, 3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResBlock(c4),
        )

        # Decoder: five 2x stages (16 -> 32 -> 64 -> 128 -> 256 -> 512).
        stage_channels = [
            (c4, c2),  # 16 -> 32
            (c2, c1),  # 32 -> 64
            (c1, c1),  # 64 -> 128
            (c1, c1),  # 128 -> 256
            (c1, c1),  # 256 -> 512
        ]
        self.up_blocks = nn.ModuleList(
            [self._make_up(src, dst) for src, dst in stage_channels]
        )

        # 1x1 projection back to 3 output channels.
        self.outc = nn.Conv2d(c1, 3, 1)

    def _make_up(self, in_ch, out_ch):
        """Return one 2x stage: nearest upsample -> conv3x3 -> ReLU."""
        stage = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the encoder, all upsampling stages and the output projection."""
        features = self.down2(self.down1(self.inc(x)))
        for stage in self.up_blocks:
            features = stage(features)
        return self.outc(features)

## 5. Configuration with Auto-tuning

In [None]:
# Patch geometry: the HR patch side is the LR patch side times the scale.
SCALE_FACTOR = 8
LR_PATCH_SIZE = 64
HR_PATCH_SIZE = SCALE_FACTOR * LR_PATCH_SIZE

# Training hyper-parameters.
# INITIAL_BATCH_SIZE is only reported here; the batch size actually used
# (BATCH_SIZE) is chosen by the memory probe in a later cell.
INITIAL_BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 2  # micro-batches accumulated per optimizer step
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
PATCHES_PER_IMAGE = 10

# Model / runtime options.
BASE_CHANNELS = 24  # reduced base width to save memory
USE_MIXED_PRECISION = True
TRAIN_SPLIT = 0.9

# Input data and output locations.
HR_DIR = './dataset_4k/high_resolution'
CHECKPOINT_DIR = Path('./checkpoints_debug')
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Report the configuration, one item per line.
print(
    "\n=== 训练配置 ===",
    f"Patch: {LR_PATCH_SIZE}×{LR_PATCH_SIZE} → {HR_PATCH_SIZE}×{HR_PATCH_SIZE} ({SCALE_FACTOR}×)",
    f"初始Batch size: {INITIAL_BATCH_SIZE}",
    f"梯度累积: {GRADIENT_ACCUMULATION} steps",
    f"有效Batch size: {INITIAL_BATCH_SIZE * GRADIENT_ACCUMULATION}",
    f"基础通道数: {BASE_CHANNELS}",
    f"混合精度: {USE_MIXED_PRECISION}",
    sep="\n",
)

## 6. Test Memory Requirements

In [None]:
# Probe increasing batch sizes with one forward/backward pass each and keep
# the largest one that does not run out of GPU memory.
print("\n=== 显存测试 ===")
print("\n创建模型...")
test_model = TinyUNet8x(base_ch=BASE_CHANNELS).to(device)
total_params = sum(p.numel() for p in test_model.parameters())
print(f"模型参数: {total_params:,}")
print_gpu_memory("模型加载后")

print("\n测试前向传播...")
test_batch_sizes = [1, 2, 4, 8]
max_working_batch = 1
probe_succeeded = False  # becomes True once any batch size fits

for bs in test_batch_sizes:
    # Pre-bind the names so the cleanup in `finally` is valid even when the
    # allocation itself is what runs out of memory.
    test_lr = test_hr = out = loss = None
    try:
        clear_gpu_memory()
        if torch.cuda.is_available():
            # Make the reported peak refer to this batch size only.
            torch.cuda.reset_peak_memory_stats()

        test_lr = torch.randn(bs, 3, LR_PATCH_SIZE, LR_PATCH_SIZE, device=device)
        test_hr = torch.randn(bs, 3, HR_PATCH_SIZE, HR_PATCH_SIZE, device=device)

        print(f"\n测试 batch_size={bs}:")
        print(f"  输入显存: {get_tensor_memory(test_lr):.1f}MB")
        print(f"  目标显存: {get_tensor_memory(test_hr):.1f}MB")

        # Forward pass.
        with autocast(enabled=USE_MIXED_PRECISION):
            out = test_model(test_lr)
            loss = nn.L1Loss()(out, test_hr)

        print_gpu_memory(f"前向传播 bs={bs}")

        # Backward pass.
        loss.backward()
        _, _, peak = print_gpu_memory(f"反向传播 bs={bs}")

        max_working_batch = bs
        probe_succeeded = True
        print(f"  ✓ batch_size={bs} 可用 (峰值显存: {peak:.2f}GB)")

    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        print(f"  ✗ batch_size={bs} OOM")
        break

    finally:
        # Drop every reference to the probe tensors (also after an OOM) and
        # discard gradients so they do not carry over into the next probe.
        test_lr = test_hr = out = loss = None
        test_model.zero_grad()

# Delete the probe model BEFORE emptying the cache; otherwise its memory is
# still referenced and cannot be returned to the device.
del test_model, test_lr, test_hr, out, loss
clear_gpu_memory()

if not probe_succeeded:
    # max_working_batch keeps its default of 1 although nothing fitted.
    print(f"\n警告: 最小的batch_size={test_batch_sizes[0]}也发生OOM，训练很可能无法运行")

print(f"\n推荐batch size: {max_working_batch}")
BATCH_SIZE = max_working_batch
print(f"使用batch size: {BATCH_SIZE}")

## 7. Data Loading

In [None]:
dataset = PatchSRDataset(
    hr_dir=HR_DIR,
    lr_patch_size=LR_PATCH_SIZE,
    scale_factor=SCALE_FACTOR,
    patches_per_image=PATCHES_PER_IMAGE,
    augment=True
)

# Split by IMAGE rather than by patch. Dataset index i maps to image
# i // patches_per_image, so a patch-level random split would place patches
# of the same image in both subsets and make the validation loss optimistic.
num_images = len(dataset.hr_images)

if num_images >= 2:
    image_order = torch.randperm(num_images).tolist()
    # Keep at least one image on each side of the split.
    num_train_images = min(max(int(TRAIN_SPLIT * num_images), 1), num_images - 1)

    def _patch_indices(image_ids):
        """Expand image indices into the dataset indices of all their patches."""
        per_image = dataset.patches_per_image
        return [img * per_image + k for img in image_ids for k in range(per_image)]

    train_set = torch.utils.data.Subset(
        dataset, _patch_indices(image_order[:num_train_images]))
    val_set = torch.utils.data.Subset(
        dataset, _patch_indices(image_order[num_train_images:]))
    train_size, val_size = len(train_set), len(val_set)
else:
    # With fewer than two images an image-level split is impossible; fall
    # back to the patch-level split (train/val overlap is unavoidable here).
    train_size = int(TRAIN_SPLIT * len(dataset))
    val_size = len(dataset) - train_size
    train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])

# NOTE: both subsets share `dataset`, so validation patches are still
# randomly cropped and augmented on every access.

# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
use_pinned_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=0, pin_memory=use_pinned_memory  # num_workers=0 is more stable
)
val_loader = DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=0, pin_memory=use_pinned_memory
)

print(f"\n训练集: {train_size} patches, {len(train_loader)} batches")
print(f"验证集: {val_size} patches, {len(val_loader)} batches")

## 8. Model Initialization

In [None]:
# Free whatever earlier cells left behind before building the real model.
clear_gpu_memory()
print("\n初始化训练组件...")

model = TinyUNet8x(base_ch=BASE_CHANNELS)
model = model.to(device)
print_gpu_memory("模型")

# L1 loss, Adam, and a cosine learning-rate schedule spanning all epochs.
criterion = nn.L1Loss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# A gradient scaler is only created when mixed precision is requested.
if USE_MIXED_PRECISION:
    scaler = GradScaler()
else:
    scaler = None

print(f"模型参数: {sum(p.numel() for p in model.parameters()):,}")

## 9. Training Functions with Memory Monitoring

In [None]:
def train_epoch(model, loader, criterion, optimizer, scaler, device,
                use_amp, grad_accum_steps=1, verbose_memory=False):
    """Train for one epoch with gradient accumulation and memory reporting.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(lr, hr)`` batches that supports ``len()``.
        criterion: Loss callable taking ``(prediction, target)``.
        optimizer: Optimizer stepped once per accumulation group.
        scaler: Gradient scaler for mixed precision, or ``None``.
        device: Device the batches are moved to.
        use_amp: Use autocast + ``scaler``; only effective when ``scaler``
            is also provided.
        grad_accum_steps: Number of micro-batches accumulated per optimizer
            step. Each micro-batch loss is divided by this value.
        verbose_memory: Print GPU memory statistics every 100 batches.

    Returns:
        The mean (un-divided) loss over all batches of the epoch.

    Raises:
        ValueError: If ``grad_accum_steps`` is smaller than 1.
    """
    if grad_accum_steps < 1:
        raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}")

    model.train()
    amp_active = bool(use_amp and scaler)
    num_batches = len(loader)
    running_loss = 0.0
    optimizer.zero_grad()

    pbar = tqdm(loader, desc='训练')
    for i, (lr, hr) in enumerate(pbar):
        lr, hr = lr.to(device), hr.to(device)

        # Forward + backward; gradients accumulate across micro-batches.
        if amp_active:
            with autocast(enabled=True):  # no device_type argument, for compatibility
                out = model(lr)
                loss = criterion(out, hr) / grad_accum_steps
            scaler.scale(loss).backward()
        else:
            out = model(lr)
            loss = criterion(out, hr) / grad_accum_steps
            loss.backward()

        # Step after every full accumulation group AND after the final batch.
        # Without the final-batch case, the gradients of a trailing partial
        # group would be thrown away by the next epoch's zero_grad().
        # (A partial group is still divided by grad_accum_steps, so its
        # gradient is proportionally smaller.)
        group_complete = (i + 1) % grad_accum_steps == 0
        last_batch = (i + 1) == num_batches
        if group_complete or last_batch:
            if amp_active:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        # Read the loss once (each .item() forces a device sync) and undo
        # the accumulation scaling for reporting.
        batch_loss = loss.item() * grad_accum_steps
        running_loss += batch_loss

        # Detailed memory report every 100 batches.
        if verbose_memory and i % 100 == 0:
            print_gpu_memory(f"Batch {i}")

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'gpu': f'{torch.cuda.memory_allocated()/1024**3:.1f}GB'
        })

    return running_loss / num_batches


def validate(model, loader, criterion, device, use_amp):
    """Return the mean loss of ``model`` over ``loader`` without gradients.

    The model is switched to evaluation mode and left in that mode. When
    ``use_amp`` is truthy the forward pass and loss run under autocast.
    """
    model.eval()
    total_loss = 0.0

    def batch_loss(inputs, targets):
        # One forward pass followed by the loss computation.
        return criterion(model(inputs), targets)

    with torch.no_grad():
        for lr, hr in tqdm(loader, desc='验证'):
            lr = lr.to(device)
            hr = hr.to(device)

            if not use_amp:
                value = batch_loss(lr, hr)
            else:
                with autocast(enabled=True):  # no device_type argument, for compatibility
                    value = batch_loss(lr, hr)

            total_loss += value.item()

    return total_loss / len(loader)

## 10. Training Loop

In [None]:
# Per-epoch loss history and the best validation loss seen so far.
history = {'train_loss': [], 'val_loss': []}
best_val_loss = float('inf')

print("\n" + "="*60)
print("开始训练")
print("="*60)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 50)

    # Detailed memory reporting during the first epoch only.
    verbose = (epoch == 0)

    # Only the training/validation passes are guarded: they are the steps
    # that allocate GPU memory and can therefore run out of it.
    try:
        train_loss = train_epoch(
            model, train_loader, criterion, optimizer, scaler, device,
            USE_MIXED_PRECISION, GRADIENT_ACCUMULATION, verbose_memory=verbose
        )

        val_loss = validate(model, val_loader, criterion, device, USE_MIXED_PRECISION)

    except RuntimeError as e:
        if "out of memory" in str(e):
            print("\n!!! OOM错误 !!!")
            print_gpu_memory("OOM时")
            print("\n建议:")
            print(f"  1. 减小batch size (当前: {BATCH_SIZE})")
            print(f"  2. 增加梯度累积 (当前: {GRADIENT_ACCUMULATION})")
            print(f"  3. 减小BASE_CHANNELS (当前: {BASE_CHANNELS})")
            print(f"  4. 减小patch size (当前: {LR_PATCH_SIZE}→{HR_PATCH_SIZE})")
        # Always propagate; the hints above are printed for OOM only.
        raise

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    scheduler.step()

    print(f"\n训练损失: {train_loss:.6f}")
    print(f"验证损失: {val_loss:.6f}")
    # Printed after scheduler.step(), so this is the rate for the next epoch.
    print(f"学习率: {scheduler.get_last_lr()[0]:.6f}")
    print_gpu_memory(f"Epoch {epoch+1} 结束")

    # Save the best model so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Scheduler and scaler state are needed to resume training with
            # the same learning-rate schedule and loss scale.
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
            'val_loss': val_loss,
            'config': {
                'lr_patch': LR_PATCH_SIZE,
                'hr_patch': HR_PATCH_SIZE,
                'scale': SCALE_FACTOR,
                'base_ch': BASE_CHANNELS
            }
        }, CHECKPOINT_DIR / 'best_model.pth')
        print("✓ 保存最佳模型")

    # Periodic weights-only checkpoint every 10 epochs.
    if (epoch + 1) % 10 == 0:
        torch.save(model.state_dict(), CHECKPOINT_DIR / f'epoch_{epoch+1}.pth')
        print("✓ 保存检查点")

    # Release cached GPU memory every 5 epochs.
    if (epoch + 1) % 5 == 0:
        clear_gpu_memory()

print(f"\n训练完成！最佳验证损失: {best_val_loss:.6f}")

## 11. Plot Training Curve

In [None]:
# Plot training and validation loss per epoch, save the figure, then show it.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(history['train_loss'], label='Training', marker='o')
ax.plot(history['val_loss'], label='Validation', marker='s')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title(f'{SCALE_FACTOR}× SR Training ({LR_PATCH_SIZE}→{HR_PATCH_SIZE})')
ax.legend()
ax.grid(True)
fig.savefig(CHECKPOINT_DIR / 'curve.png', dpi=150)
plt.show()

## 12. Training Summary

In [None]:
# Final memory reading; the third value is the peak allocation in GiB.
_, _, peak_memory = print_gpu_memory("训练结束")

# Human-readable report of the configuration and the results of this run.
summary = f"""
训练总结
{'='*60}

配置:
  Patch: {LR_PATCH_SIZE}×{LR_PATCH_SIZE} → {HR_PATCH_SIZE}×{HR_PATCH_SIZE} ({SCALE_FACTOR}×)
  Batch size: {BATCH_SIZE}
  梯度累积: {GRADIENT_ACCUMULATION}
  有效batch: {BATCH_SIZE * GRADIENT_ACCUMULATION}
  基础通道: {BASE_CHANNELS}
  模型参数: {sum(p.numel() for p in model.parameters()):,}

显存使用:
  峰值显存: {peak_memory:.2f} GB

结果:
  最佳验证损失: {best_val_loss:.6f}
  总训练轮数: {len(history['train_loss'])}

模型保存: {CHECKPOINT_DIR / 'best_model.pth'}
"""

print(summary)

# The report contains non-ASCII text, so the encoding is fixed explicitly;
# the platform default codec may be unable to encode it.
with open(CHECKPOINT_DIR / 'summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary)