In [147]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import cifar10
import numpy as np


def create_pretrained_model(input_shape=(32, 32, 3), num_classes=1, weights='imagenet'):
    """Build an EfficientNetB0-backed classifier with a small dense head.

    The EfficientNetB0 convolutional base is frozen, so only the new dense
    head is trainable. The returned model is NOT compiled; the caller picks
    the optimizer, loss and metrics.

    Args:
        input_shape: (height, width, channels) of the input images. Keras'
            EfficientNet rescales internally and expects raw pixel values in
            the [0, 255] range.
        num_classes: 1 builds a single sigmoid unit (binary classification,
            use with binary cross-entropy); >= 2 builds a softmax layer with
            ``num_classes`` units.
        weights: forwarded to ``EfficientNetB0`` (``'imagenet'``, ``None`` for
            random initialisation, or a path to a weights file).

    Returns:
        An uncompiled ``Model`` mapping images to class probabilities.

    Raises:
        ValueError: if ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f'num_classes must be >= 1, got {num_classes}')

    # Load EfficientNetB0 (ImageNet weights by default) without its classifier top.
    base_model = EfficientNetB0(weights=weights, include_top=False, input_shape=input_shape)

    # Freeze the whole backbone; setting `trainable` on the model propagates to
    # every sub-layer, so only the head added below gets trained.
    base_model.trainable = False

    # Custom classification head on top of the pooled backbone features.
    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(128, activation='relu')(x)
    x = Dense(64, activation='relu')(x)
    if num_classes == 1:
        predictions = Dense(1, activation='sigmoid')(x)
    else:
        predictions = Dense(num_classes, activation='softmax')(x)

    return Model(inputs=base_model.input, outputs=predictions)

# --- Configuration: everything a reader might want to tune -----------------
SEED = 42
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.001
FROG_CLASS = 6   # CIFAR-10 label index treated as the positive class
AUGMENT = False  # set True to enable on-the-fly data augmentation

# Seed Python `random`, NumPy and TensorFlow in one call so a fresh-kernel
# run is reproducible.
tf.keras.utils.set_random_seed(SEED)

# Load CIFAR-10 data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Create binary labels for "frog" (class 6) vs. "no frog" (all other classes)
y_train_binary = np.where(y_train == FROG_CLASS, 1, 0)
y_test_binary = np.where(y_test == FROG_CLASS, 1, 0)

# Keep pixels in their raw [0, 255] range. Keras' EfficientNet contains its
# own rescaling/normalization layers and expects raw-range inputs, so dividing
# by 255 here would feed the backbone values ~255x smaller than it expects.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

# Augmentation is opt-in via AUGMENT. No featurewise_* / zca_whitening options
# are used, so `datagen.fit(...)` (which only computes statistics for those
# options) is unnecessary.
augmentation = dict(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
) if AUGMENT else {}
datagen = ImageDataGenerator(**augmentation)

# Create the full model
model = create_pretrained_model()

# "frog" is only one of ten classes, so plain accuracy is dominated by the
# negative class; also track threshold-free (AUC) and positive-class metrics.
model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE),
    loss='binary_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
    ],
)

# steps_per_epoch is omitted: Keras uses len(iterator), which rounds up, so the
# final partial batch is no longer dropped each epoch.
# NOTE(review): the test set doubles as validation data; hold out part of the
# training set instead if validation metrics are ever used for model selection
# (e.g. early stopping).
history = model.fit(
    datagen.flow(x_train, y_train_binary, batch_size=BATCH_SIZE, seed=SEED),
    epochs=EPOCHS,
    validation_data=(x_test, y_test_binary),
)

# Evaluate the model
test_metrics = model.evaluate(x_test, y_test_binary, verbose=0, return_dict=True)
loss, accuracy = test_metrics['loss'], test_metrics['accuracy']
# Accuracy of always predicting "no frog" -- the number to beat.
majority_baseline = 1.0 - float(y_test_binary.mean())
print(f'Test accuracy: {accuracy * 100:.2f}%')
print(f'Majority-class baseline accuracy: {majority_baseline * 100:.2f}%')
print(
    f"Test AUC: {test_metrics['auc']:.4f}  "
    f"precision: {test_metrics['precision']:.4f}  "
    f"recall: {test_metrics['recall']:.4f}"
)


Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 