In [1]:
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Reshape, Multiply, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping


In [2]:
def se_block(input_tensor, reduction=16):
    channel = input_tensor.shape[-1]
    se = GlobalAveragePooling2D()(input_tensor)
    se = Dense(channel // reduction, activation='relu')(se)
    se = Dense(channel, activation='sigmoid')(se)
    se = Reshape([1, 1, channel])(se)
    x = Multiply()([input_tensor, se])
    return x

def create_resnet50(input_shape, num_classes):
    base_model = ResNet50(weights=None, include_top=False, input_shape=input_shape)
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    output = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=output)
    return model

def create_se_resnet50(input_shape, num_classes):
    base_model = ResNet50(weights=None, include_top=False, input_shape=input_shape)
    x = base_model.output

    # Áp dụng SE Module cho mỗi block
    for layer in base_model.layers:
        if 'conv' in layer.name:
            x = se_block(layer.output)

    x = GlobalAveragePooling2D()(x)
    output = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=output)
    return model

In [3]:
def train_and_evaluate(model, X_train, y_train, X_test, y_test, batch_size=64, epochs=100):
    model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
    early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

    history = model.fit(X_train, y_train,
                        validation_data=(X_test, y_test),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=[early_stopping])

    return history, model

# Tải dữ liệu CIFAR-10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train, X_test = X_train / 255.0, X_test / 255.0
y_train, y_test = to_categorical(y_train), to_categorical(y_test)

input_shape = X_train.shape[1:]
num_classes = y_train.shape[-1]

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
# Tạo và huấn luyện ResNet-50
resnet50_model = create_resnet50(input_shape, num_classes)
history_resnet50, resnet50_model = train_and_evaluate(resnet50_model, X_train, y_train, X_test, y_test)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100


In [5]:
# Tạo và huấn luyện SE-ResNet-50
se_resnet50_model = create_se_resnet50(input_shape, num_classes)
history_se_resnet50, se_resnet50_model = train_and_evaluate(se_resnet50_model, X_train, y_train, X_test, y_test)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100


In [6]:
# Đánh giá mô hình
resnet50_eval = resnet50_model.evaluate(X_test, y_test)
se_resnet50_eval = se_resnet50_model.evaluate(X_test, y_test)

print("ResNet-50 Accuracy: ", resnet50_eval[1])
print("SE-ResNet-50 Accuracy: ", se_resnet50_eval[1])

ResNet-50 Accuracy:  0.7089999914169312
SE-ResNet-50 Accuracy:  0.7138000130653381
