In [45]:
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2
import ssl

# NOTE(review): this replaces the stdlib's default HTTPS context factory with one
# that skips certificate verification, so HTTPS connections opened through the
# stdlib (urllib / http.client) in this kernel are no longer authenticated, not
# just the ImageNet weight download below. It is presumably a workaround for a
# local certificate-store problem; prefer fixing the certificate store and
# deleting this line, since downloaded files could otherwise be tampered with.
ssl._create_default_https_context = ssl._create_unverified_context

In [46]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


In [47]:
image_size = (180, 180)
batch_size = 32
DATA_DIR = "PetImages"  # one sub-directory per class
VALIDATION_SPLIT = 0.2
SEED = 1337  # must be shared by both subsets so the train/val split is disjoint


def _load_split(subset):
    """Load one side ("training" or "validation") of the DATA_DIR image split.

    Both subsets are produced from the same ``validation_split`` and ``seed``,
    so they shuffle the file list identically and do not overlap. Keeping the
    call in one place prevents the two from drifting apart.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        DATA_DIR,
        validation_split=VALIDATION_SPLIT,
        subset=subset,
        seed=SEED,
        image_size=image_size,
        batch_size=batch_size,
    )


train_ds = _load_split("training")
val_ds = _load_split("validation")


Found 23410 files belonging to 2 classes.
Using 18728 files for training.
Found 23410 files belonging to 2 classes.
Using 4682 files for validation.


In [48]:
# Exploratory only: inspect the complete ImageNet classifier (default 299x299
# input, classification top included). It gets its own name so `pt_model` is not
# reused for a different model, and the reference is dropped after the summary.
imagenet_model = InceptionResNetV2(weights='imagenet')
imagenet_model.summary()
del imagenet_model

Model: "inception_resnet_v2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_10 (InputLayer)           [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
conv2d_812 (Conv2D)             (None, 149, 149, 32) 864         input_10[0][0]                   
__________________________________________________________________________________________________
batch_normalization_812 (BatchN (None, 149, 149, 32) 96          conv2d_812[0][0]                 
__________________________________________________________________________________________________
activation_1448 (Activation)    (None, 149, 149, 32) 0           batch_normalization_812[0][0]    
________________________________________________________________________________

In [49]:
# Headless ImageNet backbone (no classification top) sized for `image_size`
# RGB inputs; its feature maps feed the custom head built by `make_model`.
pt_model = InceptionResNetV2(
    input_shape=(*image_size, 3),
    include_top=False,
    weights='imagenet',
)
pt_model.summary()

Model: "inception_resnet_v2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           [(None, 180, 180, 3) 0                                            
__________________________________________________________________________________________________
conv2d_1015 (Conv2D)            (None, 89, 89, 32)   864         input_11[0][0]                   
__________________________________________________________________________________________________
batch_normalization_1015 (Batch (None, 89, 89, 32)   96          conv2d_1015[0][0]                
__________________________________________________________________________________________________
activation_1651 (Activation)    (None, 89, 89, 32)   0           batch_normalization_1015[0][0]   
________________________________________________________________________________

In [50]:
from tensorflow.keras.applications.inception_resnet_v2 import preprocess_input 
# from tensorflow.keras.applications.xception 


def make_model(input_shape, num_classes, base_model=None):
    """Build a classifier head on top of an InceptionResNetV2 feature extractor.

    Args:
        input_shape: Shape of one input image, e.g. ``image_size + (3,)``.
            Raw pixel values are expected; ``preprocess_input`` rescales them
            inside the model.
        num_classes: Number of target classes (must be >= 2). Two classes give a
            single sigmoid unit; more give a softmax over ``num_classes`` units.
        base_model: Optional pre-built feature extractor. When omitted, a fresh
            headless ImageNet InceptionResNetV2 is built for ``input_shape``, so
            the result never depends on (or nests) whatever the global
            ``pt_model`` currently refers to.

    Returns:
        An uncompiled ``keras.Model``.

    Raises:
        ValueError: If ``num_classes`` is smaller than 2.
    """
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2, got {num_classes}")
    if base_model is None:
        base_model = InceptionResNetV2(
            weights='imagenet', include_top=False, input_shape=input_shape
        )

    inputs = keras.Input(shape=input_shape)
    # Model-specific input preprocessing (no augmentation is done here).
    x = preprocess_input(inputs)
    # Feed the preprocessed tensor into the backbone. training=False keeps its
    # BatchNormalization layers in inference mode, so the ImageNet moving
    # statistics are preserved even if the backbone is unfrozen for fine-tuning.
    x = base_model(x, training=False)

    x = layers.GlobalAveragePooling2D()(x)

    if num_classes == 2:
        activation, units = "sigmoid", 1
    else:
        activation, units = "softmax", num_classes

    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(units, activation=activation)(x)
    return keras.Model(inputs, outputs)

In [51]:
# NOTE(review): this rebinds `pt_model` from the headless backbone built in the
# previous cell to the full classifier, which the next cell compiles, trains and
# checkpoints under this name. A distinct name (e.g. `model`) would make the
# notebook safer to re-run out of order.
pt_model = make_model(input_shape=image_size + (3,), num_classes=2)
pt_model.summary()


Model: "functional_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_12 (InputLayer)        [(None, 180, 180, 3)]     0         
_________________________________________________________________
tf_op_layer_RealDiv_2 (Tenso [(None, 180, 180, 3)]     0         
_________________________________________________________________
tf_op_layer_Sub_2 (TensorFlo [(None, 180, 180, 3)]     0         
_________________________________________________________________
inception_resnet_v2 (Functio (None, 4, 4, 1536)        54336736  
_________________________________________________________________
global_average_pooling2d_5 ( (None, 1536)              0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 1536)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                

In [52]:
epochs = 3

callbacks = [
    # One checkpoint per epoch; the ".h5" extension selects the HDF5 format.
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.h5"),
]

# Freeze everything except the final Dense classifier BEFORE compiling: Keras
# only picks up changes to `trainable` at the next compile(), so freezing after
# compile (as before) left the whole backbone trainable during fit().
for layer in pt_model.layers[:-1]:
    layer.trainable = False
pt_model.layers[-1].trainable = True

pt_model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

history = pt_model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)

Epoch 1/3

KeyboardInterrupt: 