# 3.0 训练调试

本 notebook 用于调试 Inpainting 网络的训练过程。

## 目标
1. 验证数据加载（本 notebook 使用合成数据）
2. 测试网络前向传播与损失函数
3. 监控训练过程并可视化预测结果

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Report the installed PyTorch build and whether a CUDA device is visible.
# A missing install is only reported here (not raised); later cells still need torch.
try:
    import torch
except ImportError:
    print("PyTorch 未安装")
else:
    print(f"PyTorch 版本: {torch.__version__}")
    print(f"CUDA 可用: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. 测试网络架构

In [None]:
from src.texture_synthesis.network import InpaintingUNet, PatchDiscriminator, count_parameters

# Build both networks with their default constructor arguments.
generator = InpaintingUNet()
discriminator = PatchDiscriminator()

# Report model sizes with thousands separators.
# NOTE(review): count_parameters lives in src.texture_synthesis.network; presumably it
# counts trainable parameters only — TODO confirm.
print(f"Generator 参数量: {count_parameters(generator):,}")
print(f"Discriminator 参数量: {count_parameters(discriminator):,}")

In [None]:
# Smoke-test the forward pass on a random 5-D tensor: batch of 2, 1 channel, 64^3 volume.
# The (N, C, D, H, W) axis order is assumed from the 3-D volumes used elsewhere — TODO confirm.
x = torch.randn(2, 1, 64, 64, 64)

# no_grad skips autograd bookkeeping. The freshly built models are still in training
# mode here, so any BatchNorm/Dropout layers (if present) behave as during training.
with torch.no_grad():
    g_out = generator(x)
    # NOTE(review): assumes the discriminator accepts the same 1-channel volume as the
    # generator input — confirm against PatchDiscriminator's expected in_channels.
    d_out = discriminator(x)

print(f"Generator 输入: {x.shape}")
print(f"Generator 输出: {g_out.shape}")
print(f"Discriminator 输出: {d_out.shape}")

## 2. 测试损失函数

In [None]:
from src.texture_synthesis.losses import InpaintingLoss

criterion = InpaintingLoss()

# Random prediction/target pair with a 10x10x10 cubic hole marked in the mask.
pred = torch.randn(2, 1, 32, 32, 32)
target = torch.randn(2, 1, 32, 32, 32)
mask = torch.zeros_like(pred)
mask[..., 10:20, 10:20, 10:20] = 1

# generator_loss returns a dict of named scalar tensors; print each component.
losses = criterion.generator_loss(pred, target, mask)
for name, value in losses.items():
    print(f"{name}: {value.item():.4f}")

## 3. 创建合成数据进行训练测试

In [None]:
def create_synthetic_batch(batch_size=4, size=32):
    """Create a synthetic inpainting batch: noisy volumes with spherical holes.

    Each target is Gaussian noise (mean 0.5, std 0.1) clamped to [0, 1],
    standing in for lung texture. For every sample, a sphere with a random
    integer centre and a radius in [3, 5] is marked in the mask. The input
    is the target with the masked voxels set to 0.

    Args:
        batch_size: Number of samples in the batch.
        size: Edge length of the cubic volume, in voxels.

    Returns:
        Tuple ``(input_data, target, mask)`` of float tensors, each shaped
        ``(batch_size, 1, size, size, size)``. ``mask`` is 1 inside the
        hole and 0 elsewhere.
    """
    # Target texture: noise around 0.5, clamped into [0, 1].
    target = torch.randn(batch_size, 1, size, size, size) * 0.1 + 0.5
    target = torch.clamp(target, 0, 1)

    # Sphere centres are kept 8 voxels from the border. For small volumes the
    # margin shrinks so np.random.randint always gets a non-empty [low, high)
    # range; the fixed margin of 8 raised ValueError for size <= 16.
    margin = min(8, (size - 1) // 2)

    # Per-axis voxel coordinates, shaped to broadcast to (size, size, size).
    coords = torch.arange(size)
    xs = coords.view(-1, 1, 1)
    ys = coords.view(1, -1, 1)
    zs = coords.view(1, 1, -1)

    mask = torch.zeros(batch_size, 1, size, size, size)
    for b in range(batch_size):
        cx, cy, cz = (int(c) for c in np.random.randint(margin, size - margin, 3))
        r = int(np.random.randint(3, 6))
        # Vectorised form of the per-voxel test (x-cx)^2 + (y-cy)^2 + (z-cz)^2 <= r^2;
        # replaces a size^3 pure-Python loop per sample.
        inside = (xs - cx) ** 2 + (ys - cy) ** 2 + (zs - cz) ** 2 <= r ** 2
        mask[b, 0][inside] = 1

    # Network input: the target with the hole zeroed out.
    input_data = target.clone()
    input_data[mask > 0] = 0

    return input_data, target, mask

In [None]:
# Build one default-sized synthetic batch and sanity-check its shapes.
input_data, target, mask = create_synthetic_batch()

print(f"Input shape: {input_data.shape}")
print(f"Target shape: {target.shape}")
print(f"Mask shape: {mask.shape}")
# The mask holds only 0/1 values, so its mean is the fraction of voxels inside the hole.
print(f"Mask ratio: {mask.mean().item():.2%}")

## 4. 简单训练循环测试

In [None]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# Fresh generator, optimizer and loss so the cell can be re-run from a clean state.
generator = InpaintingUNet().to(device)
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
criterion = InpaintingLoss()

# Short smoke-test run: a new synthetic batch each step; record the scalar total loss.
losses = []
for step in range(20):
    input_data, target, mask = (t.to(device) for t in create_synthetic_batch())

    optimizer.zero_grad()
    pred = generator(input_data)
    loss_dict = criterion.generator_loss(pred, target, mask)
    loss = loss_dict['total']
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    # Log every fifth step (0, 5, 10, 15).
    if not step % 5:
        print(f"Step {step}: loss = {loss.item():.4f}")

In [None]:
# Plot the per-step training loss recorded by the training cell.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses)
ax.set(xlabel='Step', ylabel='Loss', title='Training Loss')
plt.show()

## 5. 可视化预测结果

In [None]:
# Predict one synthetic sample with the generator in inference mode.
generator.eval()
with torch.no_grad():
    input_data, target, mask = create_synthetic_batch(batch_size=1)
    input_data = input_data.to(device)
    pred = generator(input_data)

# The hole is a randomly placed sphere, so a fixed slice index (previously 16)
# can miss it entirely. Show the slice along dim 2 that contains the most
# masked voxels instead. mask and target were never moved off the CPU.
slice_idx = int(mask[0, 0].sum(dim=(1, 2)).argmax())

panels = [
    (input_data[0, 0, slice_idx].cpu(), 'Input (with hole)'),
    (target[0, 0, slice_idx], 'Target'),
    (pred[0, 0, slice_idx].cpu(), 'Prediction'),
    (mask[0, 0, slice_idx], 'Mask'),
]

fig, axes = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (image, title) in zip(axes, panels):
    # Shared [0, 1] gray scale so input/target/prediction intensities are directly
    # comparable; prediction values outside [0, 1] are clipped in the display.
    ax.imshow(image, cmap='gray', vmin=0, vmax=1)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()