In [None]:
# 垃圾分类模型训练

本notebook用于训练垃圾分类模型并转换为TensorFlow Lite格式，以便在Android应用中使用。


In [None]:
# 导入必要的库

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

In [None]:
# Training hyperparameters
IMAGE_SIZE = 224  # Images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32   # Samples per batch
EPOCHS = 20       # Maximum number of training epochs
NUM_CLASSES = 6   # Expected classes: cardboard, glass, metal, paper, plastic, trash

# Dataset root directory
DATASET_PATH = "Garbage classification/Garbage classification"

# Fail fast when the dataset root is missing. FileNotFoundError is a subclass
# of Exception, so any handler written for the previous generic Exception
# keeps working.
if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(f"数据集路径不存在: {DATASET_PATH}")

# Keep only sub-directories: stray files in the dataset root (for example
# .DS_Store) are not categories and would inflate the count. Sorting makes the
# printed list independent of the order os.listdir happens to return.
categories = sorted(
    entry
    for entry in os.listdir(DATASET_PATH)
    if os.path.isdir(os.path.join(DATASET_PATH, entry))
)
print(f"分类类别: {categories}")
print(f"分类数量: {len(categories)}")

# NUM_CLASSES sizes the model's output layer, so a mismatch with the dataset
# would otherwise only show up later as a shape error during training.
if len(categories) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class directories in {DATASET_PATH!r}, "
        f"found {len(categories)}: {categories}"
    )


In [None]:
# Augmentation options used for the training images only
_augmentation_options = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Training pipeline: rescale + augmentation, with 20% of the data held out
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    **_augmentation_options
)

# Validation pipeline: rescale only, using the same 20% hold-out fraction
validation_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# Arguments common to both directory iterators
_shared_flow_options = dict(
    directory=DATASET_PATH,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Shuffled iterator over the 'training' subset
train_generator = train_datagen.flow_from_directory(
    subset='training',
    shuffle=True,
    **_shared_flow_options
)

# Unshuffled iterator over the 'validation' subset
validation_generator = validation_datagen.flow_from_directory(
    subset='validation',
    shuffle=False,
    **_shared_flow_options
)

# Class name -> integer index mapping reported by the training generator
class_indices = train_generator.class_indices
print("类别索引映射:")
print(class_indices)

# Inverse mapping (index -> class name), kept for decoding labels later
class_names = {index: name for name, index in class_indices.items()}

# Export one class name per line, in index order, for the Android app.
# The encoding is pinned to UTF-8 so the file content does not depend on the
# locale default of the machine that runs this notebook.
with open('class_names.txt', 'w', encoding='utf-8') as f:
    for index in sorted(class_names):
        f.write(f"{class_names[index]}\n")


In [None]:
# Convolutional stages: each is a 3x3 ReLU convolution followed by 2x2 max
# pooling. Only the first convolution declares the input shape.
feature_layers = [
    Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)),
    MaxPooling2D(2, 2),
]
# Remaining three stages differ only in their filter count
for n_filters in (64, 128, 128):
    feature_layers.append(Conv2D(n_filters, (3, 3), activation='relu'))
    feature_layers.append(MaxPooling2D(2, 2))

# Classifier head: flatten, one hidden dense layer with dropout against
# overfitting, then a softmax over NUM_CLASSES outputs
classifier_layers = [
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(NUM_CLASSES, activation='softmax'),
]

# Assemble the full network in the same layer order as listed above
model = Sequential(feature_layers + classifier_layers)

# Compile with Adam, categorical cross-entropy, and accuracy as the metric
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=0.0001),
    metrics=['accuracy'],
)

# Print the layer-by-layer summary
model.summary()


In [None]:
# Save the model to disk whenever validation accuracy reaches a new maximum
checkpoint = ModelCheckpoint(
    'garbage_classification_model_best.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

# Stop when validation loss has not improved for 5 epochs, restoring the
# best weights seen
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

# Train the model.
# len(generator) is the number of batches in one full pass, including the
# final partial batch. The previous `samples // BATCH_SIZE` dropped that last
# batch; because the validation generator is not shuffled, the dropped
# validation samples were always the same ones from the end of the last class,
# which biased the val_loss / val_accuracy used by both callbacks and made
# them disagree with a later model.evaluate(validation_generator).
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=EPOCHS,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    callbacks=[checkpoint, early_stopping]
)


In [None]:
# Per-epoch curves recorded by fit()
acc, val_acc = history.history['accuracy'], history.history['val_accuracy']
loss, val_loss = history.history['loss'], history.history['val_loss']

# One row, two panels: accuracy on the left, loss on the right
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Accuracy panel
acc_ax.plot(acc, label='训练准确率')
acc_ax.plot(val_acc, label='验证准确率')
acc_ax.legend()
acc_ax.set_title('准确率')

# Loss panel
loss_ax.plot(loss, label='训练损失')
loss_ax.plot(val_loss, label='验证损失')
loss_ax.legend()
loss_ax.set_title('损失')

fig.tight_layout()
plt.show()


In [None]:
# Score the model on the validation generator; the report below reads
# element 0 as the loss and element 1 as the accuracy.
evaluation = model.evaluate(x=validation_generator)
for report_line in (
    f"验证损失: {evaluation[0]:.4f}",
    f"验证准确率: {evaluation[1]:.4f}",
):
    print(report_line)


In [None]:
# Persist the model in its current state to an HDF5 file
saved_model_path = 'garbage_classification_model.h5'
model.save(saved_model_path)
print("模型已保存为 garbage_classification_model.h5")


In [None]:
# Convert the Keras model to TensorFlow Lite WITHOUT any optimization flag.
# This file is the baseline that the quantization cell reports as the
# "original" model size; enabling tf.lite.Optimize.DEFAULT here (as before)
# already quantizes the weights, which made that size comparison misleading.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Write the serialized TFLite model to disk
with open('garbage_classification_model.tflite', 'wb') as f:
    f.write(tflite_model)

print("TensorFlow Lite模型已保存为 garbage_classification_model.tflite")


In [None]:
# Second conversion of the same Keras model, this time with the default
# optimizations enabled and float16 listed as the supported type, to get a
# smaller file for mobile devices.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_types = [tf.float16]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

# Write the quantized TFLite model to disk
with open('garbage_classification_model_quantized.tflite', 'wb') as quantized_file:
    quantized_file.write(tflite_quant_model)

print("量化后的TensorFlow Lite模型已保存为 garbage_classification_model_quantized.tflite")
# Size report in KB: tflite_model comes from the earlier conversion cell
print(f"原始模型大小: {len(tflite_model) / 1024:.2f} KB")
print(f"量化后模型大小: {len(tflite_quant_model) / 1024:.2f} KB")
