# Experiment 2
- Load the model saved in Experiment 1 and replace only its output layer (to adapt it to the new, binary problem); all other layers are reused and frozen.
- Train and evaluate the resulting model for 50 epochs on the Cats vs. Dogs dataset (`PetImages`).
### Setup

In [13]:
import tensorflow as tf
from keras import layers
import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data
import matplotlib.pyplot as plt
import os

### Load the pretrained model

In [14]:
# Reload the full model saved at the end of Experiment 1 (architecture + weights).
model = keras.models.load_model(filepath="models/experiment1_model.keras")

### Change the output layer
Switch from multi-class to binary classification: drop the original output layer, freeze the remaining pretrained layers, and add a single sigmoid output unit.  

In [15]:
# Feature extractor: every layer of the Experiment 1 model up to the one that
# fed its old output layer (layers[-2]); the old output layer itself is dropped.
base_model = keras.Model(
    inputs=model.input,
    outputs=model.layers[-2].output,
)
# Freeze the reused layers so that training only updates the new head.
base_model.trainable = False

# Binary head: a single sigmoid unit on top of the frozen features
# (outputs the probability of label 1).
new_output = layers.Dense(units=1, activation="sigmoid")(base_model.output)

# Transfer model: original input -> frozen features -> new binary head.
new_model = keras.Model(
    inputs=base_model.input,
    outputs=new_output,
)

### Compile the model

In [16]:
# Sigmoid output pairs with binary cross-entropy; accuracy is reported as "acc".
# Adam uses a small learning rate of 1e-4.
new_model.compile(
    loss=keras.losses.BinaryCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)

### Create dataset

In [17]:
# Clean the dataset
num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        try:
            fobj = open(fpath, "rb")
            is_jfif = b"JFIF" in fobj.peek(10)
        finally:
            fobj.close()

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print(f"Deleted {num_skipped} images.")

Deleted 1578 images.


In [18]:
batch_size = 128
image_size = (180, 180)

# Split PetImages/ (one sub-folder per class) into 80% training and 20%
# validation. subset="both" returns the two datasets at once; the fixed seed
# keeps the split identical across runs.
train_ds, val_ds = keras.utils.image_dataset_from_directory(
    directory="PetImages",
    subset="both",
    validation_split=0.2,
    seed=1337,
    batch_size=batch_size,
    image_size=image_size,
)

Found 23422 files belonging to 2 classes.
Using 18738 files for training.
Using 4684 files for validation.


In [20]:
# Random transforms used to augment training images: horizontal mirroring,
# and rotation by up to +/-10% of a full turn.
data_augmentation_layers = [
    layers.RandomFlip(mode="horizontal"),
    layers.RandomRotation(factor=0.1),
]


def data_augmentation(images):
    """Pipe ``images`` through each entry of ``data_augmentation_layers``.

    Args:
        images: The batch passed to the first augmentation layer.

    Returns:
        The output of the last layer, or ``images`` itself when the list is
        empty.
    """
    augmented = images
    for augment in data_augmentation_layers:
        augmented = augment(augmented)
    return augmented

In [21]:
# NOTE(review): `augmented_train_ds` is not referenced by the training cell
# shown (fit() consumes `train_ds`, which the next cell augments); this looks
# like a leftover and could probably be removed.
augmented_train_ds = train_ds.map(
    lambda images, labels: (data_augmentation(images), labels)
)

In [None]:
# Augment the training batches on the fly (map runs in parallel), then
# prefetch both pipelines so the next batch is prepared while the current
# one is being consumed by the model.
# NOTE: this rebinds `train_ds`; re-running the cell without recreating the
# dataset stacks a second augmentation pass on top of the first.
train_ds = train_ds.map(
    lambda images, labels: (data_augmentation(images), labels),
    num_parallel_calls=tf_data.AUTOTUNE,
).prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)

### Train

In [24]:
epochs = 50

# All model artifacts for this series of experiments live under models/.
os.makedirs("models", exist_ok=True)

# Save a full model checkpoint after every epoch (epoch number in the file name).
callbacks = [
    keras.callbacks.ModelCheckpoint("models/experiment2_epoch_{epoch}.keras"),
]

# Keep the History object so per-epoch loss/accuracy can be inspected or
# plotted in later cells.
history = new_model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=callbacks,
)

# Store the final model alongside the Experiment 1 model and the checkpoints.
new_model.save("models/experiment2_model.keras")

Epoch 1/50
[1m147/147[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m255s[0m 2s/step - acc: 0.5615 - loss: 0.6993 - val_acc: 0.6838 - val_loss: 0.6107
Epoch 2/50
[1m147/147[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m251s[0m 2s/step - acc: 0.6534 - loss: 0.6241 - val_acc: 0.7054 - val_loss: 0.5830
Epoch 3/50
[1m147/147[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m251s[0m 2s/step - acc: 0.6784 - loss: 0.6016 - val_acc: 0.7143 - val_loss: 0.5692
Epoch 4/50
[1m147/147[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m247s[0m 2s/step - acc: 0.6897 - loss: 0.5870 - val_acc: 0.7231 - val_loss: 0.5607
Epoch 5/50
[1m147/147[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m249s[0m 2s/step - acc: 0.6954 - loss: 0.5802 - val_acc: 0.7282 - val_loss: 0.5542
Epoch 6/50
[1m147/147[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m251s[0m 2s/step - acc: 0.7062 - loss: 0.5706 - val_acc: 0.7327 - val_loss: 0.5492
Epoch 7/50
[1m147/147[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m251s[0m 2s/