In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, regularizers
import numpy as np

# Fix the TensorFlow and NumPy global seeds for reproducibility.
tf.random.set_seed(42)
np.random.seed(42)

# ----------------------------
# 1. Data loading and preprocessing
# ----------------------------
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Scale pixel values to the [0, 1] range.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Convert the integer class labels to one-hot vectors (10 classes).
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Data augmentation: small random rotations, shifts, zooms and horizontal flips.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    zoom_range=0.1
)
# NOTE: datagen.fit(x_train) is intentionally not called. fit() only computes
# dataset statistics for featurewise_center / featurewise_std_normalization /
# zca_whitening, none of which are enabled above; calling it would merely copy
# the whole training array without affecting the augmentation.

# ----------------------------
# 2. 构建 ResNet-18 模型（适配 CIFAR-10）
# ----------------------------
def residual_block(x, filters, kernel_size=3, stride=1, conv_shortcut=False):
    """Apply a two-convolution residual block to the Keras tensor `x`.

    The main path is Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm; its output
    is added to the shortcut and passed through a final ReLU.

    Args:
        x: Input Keras tensor (channels-last feature map).
        filters: Number of output channels of both convolutions.
        kernel_size: Kernel size of the two main-path convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is used).
        conv_shortcut: If True, always use a 1x1 Conv + BatchNorm projection
            on the shortcut. If False, the identity shortcut is used whenever
            the shapes allow it; a projection is inserted automatically when
            the block changes the spatial size or the channel count.

    Returns:
        The output Keras tensor of the block.
    """
    # An identity shortcut is only valid when the main path keeps both the
    # spatial resolution and the channel count; otherwise Add() would fail
    # with a shape mismatch, so fall back to the projection shortcut.
    stride_values = stride if isinstance(stride, (tuple, list)) else (stride,)
    changes_resolution = any(s != 1 for s in stride_values)
    changes_channels = x.shape[-1] != filters

    if conv_shortcut or changes_resolution or changes_channels:
        shortcut = layers.Conv2D(filters, 1, strides=stride, use_bias=False)(x)
        shortcut = layers.BatchNormalization()(shortcut)
    else:
        shortcut = x

    # Biases are omitted because each convolution is followed by BatchNorm.
    x = layers.Conv2D(filters, kernel_size, strides=stride, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(filters, kernel_size, padding='same', use_bias=False)(x)
    x = layers.BatchNormalization()(x)

    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)
    return x

def ResNet18(input_shape=(32, 32, 3), num_classes=10):
    """Build a ResNet-18 style classifier as a Keras functional model.

    The network is a 3x3 stem convolution with 64 filters (stride 1), followed
    by four stages of two residual blocks each (64, 128, 256 and 512 filters),
    global average pooling and a softmax Dense layer.

    Args:
        input_shape: Shape of a single input image, without the batch axis.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled `Model` mapping the input images to class probabilities.
    """
    inputs = layers.Input(shape=input_shape)

    # Stem: one stride-1 3x3 convolution followed by BatchNorm and ReLU.
    stem = layers.Conv2D(64, 3, strides=1, padding='same', use_bias=False)(inputs)
    stem = layers.BatchNormalization()(stem)
    features = layers.ReLU()(stem)

    # (filters, stride of the first block) for each of the four stages.
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for width, first_stride in stage_specs:
        # The first block of every stage uses a projection shortcut,
        # the second one an identity shortcut.
        features = residual_block(features, width, stride=first_stride, conv_shortcut=True)
        features = residual_block(features, width)

    # Classification head: global average pooling + softmax.
    pooled = layers.GlobalAveragePooling2D()(features)
    class_probs = layers.Dense(num_classes, activation='softmax')(pooled)

    return Model(inputs, class_probs)

model = ResNet18()

# ----------------------------
# 3. Compile the model
# ----------------------------
# Adam with a fixed initial learning rate. No warmup or cosine-annealing
# schedule is configured here; the learning rate is only lowered during
# training by the ReduceLROnPlateau callback passed to fit().

model.compile(
    # `learning_rate` is the supported argument name; the old `lr` alias is
    # deprecated and is rejected by newer Keras releases.
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# ----------------------------
# 4. Train the model
# ----------------------------

# Hold out the tail of the training set for validation. Early stopping (with
# restore_best_weights) and the LR schedule select the model based on the
# validation data, so using the test set here would leak it into model
# selection and bias the final test accuracy. Basic slicing returns views,
# so this does not copy the image arrays.
VAL_SIZE = 5000
x_fit, y_fit = x_train[:-VAL_SIZE], y_train[:-VAL_SIZE]
x_val, y_val = x_train[-VAL_SIZE:], y_train[-VAL_SIZE:]

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',         # watch the validation loss
    patience=5,                 # stop after 5 consecutive epochs without improvement
    restore_best_weights=True   # roll back to the weights with the best val_loss
)

# Halve the learning rate when val_loss has not improved for 3 epochs.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    verbose=1
)

# Train on augmented batches produced by the data generator.
history = model.fit(
    datagen.flow(x_fit, y_fit, batch_size=128),
    epochs=50,
    validation_data=(x_val, y_val),
    verbose=1,
    callbacks=[early_stop, reduce_lr],
)

# ----------------------------
# 5. Evaluation
# ----------------------------
# evaluate() returns [loss, accuracy] for the metrics configured in compile().
eval_results = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_acc = eval_results
print(f"\nTest Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")

  super().__init__(name, **kwargs)


Epoch 1/50


In [None]:
# Save the trained model to an HDF5 file in the current working directory.
model_path = "my_model.h5"
model.save(model_path)

In [1]:
import numpy as np
import tensorflow as tf
from PIL import Image
import os

# Directory that receives the exported PNG files (created if missing).
output_dir = "cifar10_test_images_tf"
os.makedirs(output_dir, exist_ok=True)

# CIFAR-10 class names, indexed by integer label.
classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Only the test split of CIFAR-10 is needed here.
(_, _), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

# Export the first N test images as "<index>_<class>.png".
N = 10
for i, (image, label_row) in enumerate(zip(test_images[:N], test_labels[:N])):
    # Each label row holds a single integer class id.
    class_name = classes[label_row[0]]
    filename = f"{i:02d}_{class_name}.png"
    target_path = os.path.join(output_dir, filename)
    Image.fromarray(image).save(target_path)

print(f"已保存 {N} 张测试图像到 '{output_dir}' 文件夹。")

已保存 10 张测试图像到 'cifar10_test_images_tf' 文件夹。
