# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
import os

# 数据生成（增强和预处理）
def generate_synthetic_data(num_samples=1000, image_size=(64, 64)):
    """Generate random input/target image pairs with simple augmentation.

    Both arrays are drawn uniformly from [0, 1). Each (input, target) pair is
    then rotated by the same random multiple of 90 degrees so the two images
    stay pixel-aligned, and Gaussian noise (mean 0, std 0.01) is added to the
    inputs only.

    Args:
        num_samples: Number of image pairs to generate.
        image_size: Shape of a single image, e.g. ``(height, width)``.

    Returns:
        Tuple ``(X, y)`` of arrays, each with shape
        ``(num_samples, *image_size)``.
    """
    X = np.random.rand(num_samples, *image_size)
    y = np.random.rand(num_samples, *image_size)

    # A quarter turn swaps height and width, so it only preserves the sample
    # shape for square images; otherwise restrict to 0 or 180 degrees.
    is_square = len(image_size) >= 2 and image_size[0] == image_size[1]
    if is_square:
        quarter_turns = np.random.randint(0, 4, size=num_samples)
    else:
        quarter_turns = 2 * np.random.randint(0, 2, size=num_samples)

    # Rotate input and target together; writing in place keeps the
    # (num_samples, *image_size) shape even when num_samples == 0.
    for i, k in enumerate(quarter_turns):
        X[i] = np.rot90(X[i], k).copy()
        y[i] = np.rot90(y[i], k).copy()

    X += np.random.normal(0, 0.01, X.shape)
    return X, y

# U-Net结构，支持可变规模
def build_unet(input_shape=(64, 64, 1), base_filters=32, depth=3):
    """Build and compile a U-Net whose size is set by width and depth.

    The encoder has ``depth`` stages (two 3x3 ReLU convolutions followed by
    2x2 max pooling), the bottleneck uses ``base_filters * 2**depth`` filters,
    and the decoder mirrors the encoder with 2x upsampling and concatenation
    of the matching encoder output. A 1x1 linear convolution produces a
    single-channel output.

    Args:
        input_shape: Shape of one input sample passed to ``layers.Input``.
            NOTE(review): presumably the spatial dimensions must be divisible
            by ``2**depth`` for the skip concatenations to line up — confirm.
        base_filters: Filter count of the first encoder stage; doubled at
            every deeper stage.
        depth: Number of down-/up-sampling stages.

    Returns:
        A Keras model compiled with Adam (learning rate 0.001), MSE loss and
        the ``mse`` and ``mae`` metrics.
    """

    def double_conv(tensor, filters):
        # Two consecutive 3x3 same-padded ReLU convolutions.
        for _ in range(2):
            tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = layers.Input(shape=input_shape)
    features = inputs
    encoder_outputs = []

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    for level in range(depth):
        features = double_conv(features, base_filters * 2**level)
        encoder_outputs.append(features)
        features = layers.MaxPooling2D((2, 2))(features)

    # Bottleneck
    features = double_conv(features, base_filters * 2**depth)

    # Decoder: walk the stages from deepest to shallowest.
    for level in range(depth - 1, -1, -1):
        features = layers.UpSampling2D((2, 2))(features)
        features = layers.Concatenate()([features, encoder_outputs[level]])
        features = double_conv(features, base_filters * 2**level)

    outputs = layers.Conv2D(1, (1, 1), activation='linear')(features)

    model = models.Model(inputs, outputs)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='mse', metrics=['mse', 'mae'])
    return model

# Generate the synthetic dataset used for training
X, y = generate_synthetic_data()

# Append a trailing channel axis so the arrays match the model's input shape
X = X[..., np.newaxis]
y = y[..., np.newaxis]

# Model hyper-parameters for the different U-Net sizes
model_configs = [
    {'name': 'Small U-Net', 'base_filters': 16, 'depth': 2},
    {'name': 'Medium U-Net', 'base_filters': 32, 'depth': 3},
    {'name': 'Large U-Net', 'base_filters': 64, 'depth': 4},
]

results = []

# __file__ is not defined when this code runs as a notebook cell or in an
# interactive session; fall back to the current working directory there.
try:
    base_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    base_dir = os.getcwd()

result_dir = os.path.join(base_dir, 'results')
os.makedirs(result_dir, exist_ok=True)

# Train every configuration, record its validation metrics and save its loss
# curve. Trained models are kept by name so the comparison step further down
# can reuse one instead of training it a second time.
trained_models = {}
for config in model_configs:
    print(f"\n训练模型: {config['name']}")
    model = build_unet(input_shape=(64, 64, 1), base_filters=config['base_filters'], depth=config['depth'])
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    history = model.fit(X, y, epochs=60, batch_size=32, validation_split=0.2, callbacks=[early_stopping], verbose=0)
    trained_models[config['name']] = model
    # NOTE(review): these are the last-epoch values; with
    # restore_best_weights=True the restored weights may belong to an earlier
    # epoch — confirm which epoch is meant to be reported.
    val_mse = history.history['val_mse'][-1]
    val_mae = history.history['val_mae'][-1]
    results.append({'name': config['name'], 'val_mse': val_mse, 'val_mae': val_mae, 'history': history})
    print(f"最终验证集MSE: {val_mse:.4f}, MAE: {val_mae:.4f}")
    # Save this model's validation loss curve
    plt.figure()
    plt.plot(history.history['val_loss'], label=f"{config['name']} Val Loss")
    plt.xlabel('Epoch')
    plt.ylabel('Val MSE Loss')
    plt.legend()
    # Double-quoted f-string: reusing the outer quote character inside the
    # replacement field is a SyntaxError before Python 3.12.
    plt.title(f"{config['name']} 验证损失曲线")
    plt.savefig(os.path.join(result_dir, f"{config['name'].replace(' ', '_')}_val_loss.png"))
    plt.close()

# Plot the validation loss curves of all model sizes in one figure
plt.figure(figsize=(10, 6))
for res in results:
    plt.plot(res['history'].history['val_loss'], label=f"{res['name']} Val Loss")
plt.xlabel('Epoch')
plt.ylabel('Val MSE Loss')
plt.legend()
plt.title('不同规模U-Net模型验证损失对比')
plt.savefig(os.path.join(result_dir, "all_models_val_loss.png"))
plt.close()

# Pick a random sample and compare input, target and prediction using the
# largest model
idx = np.random.randint(0, X.shape[0])
sample_input = X[idx:idx+1]
sample_true = y[idx]
# Reuse the Large U-Net trained in the loop above; only train a fresh one if
# that configuration is not present in model_configs.
best_model = trained_models.get('Large U-Net')
if best_model is None:
    best_model = build_unet(input_shape=(64, 64, 1), base_filters=64, depth=4)
    best_model.fit(X, y, epochs=60, batch_size=32, validation_split=0.2, callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)], verbose=0)
sample_pred = best_model.predict(sample_input)[0]

plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.title('Input Data')
plt.imshow(sample_input[0, :, :, 0], cmap='gray')
plt.subplot(1, 3, 2)
plt.title('True Property')
plt.imshow(sample_true[:, :, 0], cmap='gray')
plt.subplot(1, 3, 3)
plt.title('Predicted Property (Large U-Net)')
plt.imshow(sample_pred[:, :, 0], cmap='gray')
plt.savefig(os.path.join(result_dir, "comparison_large_unet.png"))
plt.close()

# Write every model's validation MSE and MAE to a text file, one line per model
metrics_path = os.path.join(result_dir, "metrics.txt")
with open(metrics_path, "w") as f:
    f.writelines(
        f"{res['name']}: MSE={res['val_mse']:.4f}, MAE={res['val_mae']:.4f}\n"
        for res in results
    )