## 自动加载数据集

In [1]:
import tensorflow as tf
import os

# 定义数据集路径
dataset_path = "../dataset/Garbage classification/Garbage classification"  # 根据你的实际路径调整
img_size = (224, 224)  # 目标图像大小
batch_size = 32  # 你可以调整 batch_size,

# 直接从文件夹中加载数据，并自动划分训练/验证集
train_dataset = tf.keras.utils.image_dataset_from_directory(
    dataset_path,
    validation_split=0.2,  # 20% 作为验证集
    subset="training",
    seed=123,  # 设定随机种子，保证每次划分一致
    image_size=img_size,
    batch_size=batch_size
)

val_dataset = tf.keras.utils.image_dataset_from_directory(
    dataset_path,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=img_size,
    batch_size=batch_size
)

# 打印类别索引映射
class_names = train_dataset.class_names
print(f"类别索引映射: {class_names}")

# 预处理（标准化）
normalization_layer = tf.keras.layers.Rescaling(1./255)
train_dataset = train_dataset.map(lambda x, y: (normalization_layer(x), y))
val_dataset = val_dataset.map(lambda x, y: (normalization_layer(x), y))


Found 2527 files belonging to 6 classes.
Using 2022 files for training.
Found 2527 files belonging to 6 classes.
Using 505 files for validation.
类别索引映射: ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']


## 使用 MobileNetV2 进行训练

In [2]:
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

# 加载预训练模型
base_model = MobileNetV2(input_shape=(224, 224, 3), include_top=False, weights="imagenet")
base_model.trainable = False  # 冻结预训练权重

# 添加分类层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(128, activation="relu")(x)
output = Dense(len(class_names), activation="softmax")(x)

# 构建模型
model = Model(inputs=base_model.input, outputs=output)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# 训练模型
history = model.fit(train_dataset, validation_data=val_dataset, epochs=10)

Epoch 1/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 610ms/step - accuracy: 0.5736 - loss: 1.1452 - val_accuracy: 0.7960 - val_loss: 0.5522
Epoch 2/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 611ms/step - accuracy: 0.8472 - loss: 0.4590 - val_accuracy: 0.7644 - val_loss: 0.5918
Epoch 3/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 605ms/step - accuracy: 0.8952 - loss: 0.3195 - val_accuracy: 0.8198 - val_loss: 0.4795
Epoch 4/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 604ms/step - accuracy: 0.9355 - loss: 0.2255 - val_accuracy: 0.8337 - val_loss: 0.4976
Epoch 5/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m39s[0m 607ms/step - accuracy: 0.9588 - loss: 0.1583 - val_accuracy: 0.8297 - val_loss: 0.5032
Epoch 6/10
[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 624ms/step - accuracy: 0.9784 - loss: 0.1033 - val_accuracy: 0.8337 - val_loss: 0.5062
Epoch 7/10
[1m64/64[