<a href="https://colab.research.google.com/github/jimmy-pink/colab-playground/blob/main/pre-trained/ResNet50V2-FolderIconRecognition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 使用ResNet50V2微调以解决FolderIcon二分类问题

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, Input
from tensorflow.keras.optimizers import Adam

### 数据准备

In [None]:
# Dataset is a copy of: https://drive.google.com/drive/folders/1xwtf91GSyeIc7ohpCKsDCYv3zXgKa0sf
# Mount Google Drive so that files under MyDrive are reachable from this runtime.
from google.colab import drive

drive.mount('/content/drive')

# Root folder of the image dataset inside Google Drive (note the trailing '/').
base_dir = '/content/drive/MyDrive/Google AI Studio/data/folder-icon-images/'
# Because base_dir ends with '/', appending the sub-folder name yields a valid path.
train_dir = f"{base_dir}is_folder"
# NOTE(review): not referenced by the cells visible here.
drive_train_validate_dir = f"{base_dir}train_validate"

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Fraction of the images in train_dir held out for validation. It is shared
# by both generators below so that their training/validation subsets
# complement each other.
VALIDATION_SPLIT = 0.3

# Training data: rescale pixels to [0, 1] and apply random horizontal flips
# as augmentation. (A previous, immediately-overwritten generator was removed.)
train_datagen = ImageDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    validation_split=VALIDATION_SPLIT
)

# Validation data must not be augmented; otherwise validation metrics are
# measured on randomly flipped images. Use a separate generator with the same
# rescale and the same validation_split. flow_from_directory splits the files
# deterministically, so subset='validation' here picks exactly the files that
# subset='training' excludes.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT
)

# Training-set generator (augmented).
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary',  # 'binary' for two classes, 'categorical' for multi-class
    subset='training'
)

# Validation-set generator (no augmentation).
validation_generator = val_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='binary',
    subset='validation'
)
print(f"找到的训练样本数: {train_generator.samples}")
print(f"找到的验证样本数: {validation_generator.samples}")
# Sanity check: pull one batch to confirm the image and label shapes.
images, labels = next(train_generator)
print("图像形状:", images.shape)
print("标签形状:", labels.shape)

In [None]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
# Integer class index of every sample in the training subset.
y_train = train_generator.classes
# Labels that actually occur in the training subset.
present_classes = np.unique(y_train)
# 'balanced' weighting: n_samples / (n_classes * count_of_class).
weights = compute_class_weight(
    'balanced',
    classes=present_classes,
    y=y_train
)
# Key each weight by its real class label instead of its position. With
# enumerate() the keys are only correct when every label 0..k-1 is present;
# a missing label would shift the weights onto the wrong classes.
class_weights = {int(c): float(w) for c, w in zip(present_classes, weights)}
print("类别权重:", class_weights)

### 建模

In [None]:
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras import layers, models

base_model = ResNet50V2(
    weights='imagenet',  # start from ImageNet pre-trained weights
    include_top=False,   # drop the ImageNet classification head
    input_shape=(224, 224, 3)
)

# Fraction of base_model's layers (counted from the input side) kept frozen;
# only the remaining tail of the network is fine-tuned.
FREEZE_FRACTION = 0.9
freeze_until = int(len(base_model.layers) * FREEZE_FRACTION)
for i, layer in enumerate(base_model.layers):
    # BatchNormalization layers stay frozen even in the fine-tuned tail: a
    # non-trainable BN layer runs in inference mode in TF2 Keras, so its
    # pre-trained moving statistics are not overwritten by small batches.
    is_batch_norm = isinstance(layer, layers.BatchNormalization)
    layer.trainable = i >= freeze_until and not is_batch_norm

In [None]:
# Optional: where the best model checkpoint will be written.
import os

save_dir = '/content/drive/MyDrive/Google AI Studio/saved_models'
# Create the directory (and parents) if missing. exist_ok=True removes the
# separate existence check and does not fail if the directory already exists.
os.makedirs(save_dir, exist_ok=True)
# NOTE(review): '.h5' is the legacy HDF5 format; newer Keras versions prefer
# '.keras' for full-model checkpoints -- confirm against the installed version.
model_file_save_path = f'{save_dir}/ResNet50V2_folder_icon_shape_predict_model.h5'

In [None]:
# `regularizers` and `AdamW` are used below but no earlier cell imports them,
# which made this cell fail with NameError. (AdamW requires TF >= 2.11.)
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import AdamW

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),  # average feature maps over spatial dims
    layers.Dropout(0.6),
    # NOTE(review): strong L2 (0.05) is applied on top of AdamW weight decay;
    # hyperparameters left unchanged.
    layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.05)),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')  # single sigmoid unit for binary output
])

# AdamW: Adam with decoupled weight decay.
optimizer = AdamW(learning_rate=1e-4, weight_decay=1e-4)
model.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',  # pairs with the sigmoid output / binary labels
    metrics=['accuracy']
)

In [None]:
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

class myCallback(Callback):
    """Stop training once the validation metrics reach target values.

    At the end of each epoch, training is stopped when
    ``val_accuracy >= target_accuracy`` and ``val_loss < max_loss``.
    Epochs whose logs lack either metric are ignored.

    Args:
        target_accuracy: Minimum validation accuracy required to stop.
        max_loss: Validation loss must be strictly below this value to stop.
    """

    def __init__(self, target_accuracy=0.98, max_loss=0.2):
        super().__init__()
        self.target_accuracy = target_accuracy
        self.max_loss = max_loss

    def on_epoch_end(self, epoch, logs=None):
        # `logs` defaults to None instead of a shared mutable {}.
        logs = logs or {}
        val_accuracy = logs.get('val_accuracy')
        val_loss = logs.get('val_loss')
        # Without validation metrics, comparing None with a float would raise
        # TypeError; skip such epochs instead.
        if val_accuracy is None or val_loss is None:
            return
        if val_accuracy >= self.target_accuracy and val_loss < self.max_loss:
            self.model.stop_training = True

callbacks = [
    # Keep only the checkpoint with the highest validation accuracy on disk.
    ModelCheckpoint(
        model_file_save_path,
        monitor='val_accuracy',  # track validation accuracy
        save_best_only=True,     # overwrite only when it improves
        mode='max'               # higher is better
    ),
    # Stop after 5 epochs without val_loss improvement. When it triggers,
    # restore_best_weights rolls the in-memory model back to the best epoch;
    # otherwise the model left after fit() holds the later, worse weights.
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    # Stop as soon as the validation targets are met.
    myCallback()
]

history = model.fit(
    train_generator,
    epochs=50,
    validation_data=validation_generator,
    class_weight=class_weights,  # counteract class imbalance in the loss
    verbose=1,
    callbacks=callbacks
)