In [1]:
import tensorflow as tf

train_data, valid_data = tf.keras.utils.image_dataset_from_directory(
    "./cats_vs_dogs",
    label_mode="categorical",
    image_size=(150, 150),
    validation_split=0.3,
    subset="both",
    seed=0)


Found 1877 files belonging to 2 classes.
Using 1314 files for training.
Using 563 files for validation.


In [2]:
classes = train_data.class_names
classes

['cat', 'dog']

In [3]:
import joblib
joblib.dump(classes, "classes.joblib")

['classes.joblib']

In [4]:
n_classes = len(classes)
n_classes

2

In [5]:
n_batches = valid_data.cardinality()
n_batches

<tf.Tensor: shape=(), dtype=int64, numpy=18>

In [6]:
test_size = n_batches // 2
test_data = valid_data.take(test_size)
valid_data = valid_data.skip(test_size)

In [7]:
for transform in [
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1)]:
    
    train_data = train_data.map(lambda x, y: (transform(x), y))




In [8]:
base_model = tf.keras.applications.Xception(
    weights="imagenet",
    input_shape=(150, 150, 3),
    include_top=False)

base_model.trainable = False

inputs = tf.keras.Input(shape=(150, 150, 3))
scaling = tf.keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
outputs = scaling(inputs)
outputs = base_model(outputs, training=False)
outputs = tf.keras.layers.GlobalAveragePooling2D()(outputs)
outputs = tf.keras.layers.Dense(n_classes)(outputs)

model = tf.keras.Model(inputs, outputs)
model.summary(show_trainable=True)


Model: "model"
____________________________________________________________________________
 Layer (type)                Output Shape              Param #   Trainable  
 input_2 (InputLayer)        [(None, 150, 150, 3)]     0         Y          
                                                                            
 rescaling (Rescaling)       (None, 150, 150, 3)       0         Y          
                                                                            
 xception (Functional)       (None, 5, 5, 2048)        2086148   N          
                                                       0                    
                                                                            
 global_average_pooling2d (  (None, 2048)              0         Y          
 GlobalAveragePooling2D)                                                    
                                                                            
 dense (Dense)               (None, 2)                 4098 

In [9]:
model.compile(
    optimizer=tf.keras.optimizers.AdamW(),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()])

epochs = 2
model.fit(train_data, epochs=epochs, validation_data=valid_data)
model.evaluate(test_data)

Epoch 1/2

Epoch 2/2


[0.33788612484931946, 0.0347222238779068]

In [10]:
# base_model.trainable = True
# model.summary(show_trainable=True)

# model.compile(
#     optimizer=tf.keras.optimizers.AdamW(1e-5),
#     loss=tf.keras.losses.CategoricalCrossentropy(),
#     metrics=[tf.keras.metrics.CategoricalAccuracy()])

# epochs = 1
# model.fit(train_data, epochs=epochs, validation_data=valid_data)
# model.evaluate(test_data)

In [11]:
model.save("model.keras")

In [12]:
%%writefile app.py
# !pip install gradio ipywidgets
import joblib
import tensorflow as tf
import gradio as gr
import numpy as np

model = tf.keras.models.load_model("model.keras")
classes = joblib.load("classes.joblib")

def predict(path):
    image = tf.keras.preprocessing.image.load_img(path, target_size=(150, 150))
    image = tf.keras.preprocessing.image.img_to_array(image)
    image = np.expand_dims(image, axis=0)

    predicted = model.predict(image)[0].argmax(axis=-1)
    return classes[predicted]

# https://www.gradio.app/guides
with gr.Blocks() as blocks:
    path = gr.Image(label="Image", type="filepath")
    label = gr.Textbox(label="Label")

    inputs = [path]
    outputs = [label]

    predict_btn = gr.Button("Predict")
    predict_btn.click(predict, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    blocks.launch() # Local machine only
    # blocks.launch(server_name="0.0.0.0") # LAN access to local machine
    # blocks.launch(share=True) # Public access to local machine
    # predict("cats_vs_dogs/cat/0.jpg")

Overwriting app.py


In [13]:
%run app.py

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
