In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
# Load DeepWeeds through TFDS. The `info` metadata is what the next cell uses;
# `as_supervised=True` makes each element an (image, label) pair, matching the
# (image, label) signature of the preprocessing functions below.
dataset, info = tfds.load(
    "deep_weeds",
    with_info=True,
    as_supervised=True,
)

In [3]:
# Number of label classes (sizes the softmax head) and number of examples in
# the "train" split, which the test/validation/training subsets are sliced from.
n_classes = info.features["label"].num_classes
dataset_size = info.splits["train"].num_examples
# Only a cell's last expression is displayed, so show both values together
# instead of leaving `n_classes` as a silently discarded expression.
n_classes, dataset_size

17509

In [4]:
# Carve test / validation / training subsets out of the "train" split with
# TFDS percentage slicing:
#   [0 %, 5 %)   -> test
#   [5 %, 15 %)  -> validation
#   [15 %, 35 %) -> training
# Both bounds of every slice are written in the same unit (percent); the
# training end bound was previously a bare "35" with no "%".
# NOTE(review): only 20 % of the data is used for training — use
# "train[15%:]" if the rest of the split was meant to be included.
test_set = tfds.load("deep_weeds", as_supervised=True, split="train[:5%]")
valid_set = tfds.load("deep_weeds", as_supervised=True, split="train[5%:15%]")
train_set = tfds.load("deep_weeds", as_supervised=True, split="train[15%:35%]")

In [5]:
def preprocess_resize(image, label):
    """Resize `image` to 224x224 pixels; `label` is passed through unchanged."""
    resized = tf.image.resize(image, [224, 224])
    return resized, label

def random_crop(image, crop_percent=90):
    """Randomly crop a square patch out of `image`.

    The side of the square is `crop_percent` percent of the image's shorter
    spatial side, computed with integer arithmetic (rounded down). The channel
    count is read from the image instead of being hard-coded to 3, so
    non-RGB images are cropped correctly too.

    Args:
        image: image tensor, assumed to be laid out as (height, width,
            channels) — inferred from the indexing below.
        crop_percent: integer percentage of the shorter side to keep.
            Defaults to 90, the previous fixed behaviour.

    Returns:
        A randomly positioned crop of shape (side, side, channels).
    """
    shape = tf.shape(image)
    height, width, channels = shape[0], shape[1], shape[2]
    side = tf.minimum(height, width) * crop_percent // 100
    return tf.image.random_crop(image, tf.stack([side, side, channels]))

def preprocess_augment(image, label):
    """Augment a training image: random square crop, then random horizontal flip.

    `label` is returned untouched.
    """
    augmented = tf.image.random_flip_left_right(random_crop(image))
    return augmented, label

def preprocess_xception(image, label):
    """Apply Xception's input preprocessing to `image`; pass `label` through.

    Refers to `tf.keras` directly so the function works regardless of whether
    the separate `from tensorflow import keras` statement has already run.
    """
    return tf.keras.applications.xception.preprocess_input(image), label

In [6]:
from tensorflow import keras

# Input-pipeline settings.
batch_size = 32
prefetch = 1


def _eval_pipeline(ds):
    """Evaluation pipeline: resize -> Xception preprocessing -> batch -> prefetch."""
    return (
        ds
        .map(preprocess_resize)
        .map(preprocess_xception)
        .batch(batch_size)
        .prefetch(prefetch)
    )


# Batch the *preprocessed* datasets. Previously `.batch()` was applied to the
# raw `test_set` / `valid_set`, which silently discarded the preprocessing.
test_set_final = _eval_pipeline(test_set)
valid_set_final = _eval_pipeline(valid_set)

# Training pipeline: shuffle -> augment -> resize -> preprocess -> batch.
# The shuffle used to be built and then overwritten; it is now kept.
# `.repeat()` is deliberately left out: if you add it, also pass
# `steps_per_epoch` to `model.fit`, otherwise the first epoch never ends.
train_set_final = (
    train_set
    .shuffle(1000)
    .map(preprocess_augment)
    .map(preprocess_resize)
    .map(preprocess_xception)
    .batch(batch_size)
    .prefetch(prefetch)
)

In [7]:
# Transfer-learning model: ImageNet-pretrained Xception without its top
# classifier, followed by global average pooling and a new softmax layer with
# one unit per class.
base_model = keras.applications.xception.Xception(
    include_top=False,
    weights="imagenet",
)
avg = keras.layers.GlobalAveragePooling2D()(base_model.output)
output = keras.layers.Dense(units=n_classes, activation="softmax")(avg)
model = keras.Model(base_model.input, output)

In [8]:
# Freeze every layer of the pretrained Xception base so that, during the
# following training, only the layers added on top of it (the Dense softmax
# head) have their weights updated. This runs before `model.compile`; per the
# Keras transfer-learning guide, changing `trainable` after compiling requires
# re-compiling for it to take effect.
# NOTE: the loop variable `layer` stays defined in the notebook namespace.
for layer in base_model.layers:
    layer.trainable = False

In [9]:
# SGD with momentum and an inverse-time learning-rate decay:
#     lr(step) = 0.2 / (1 + 0.01 * step)
# This is the formula the legacy `decay=0.01` argument applied. Newer Keras
# optimizers deprecate `lr=` and reject `decay=`, so the schedule is passed
# explicitly through `learning_rate=`.
lr_schedule = keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.2,
    decay_steps=1,
    decay_rate=0.01,
)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=optimizer,
    metrics=["accuracy"],
)

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
# Start an inline TensorBoard on port 6008, reading the ./training_logs root
# that the training cell writes its timestamped run directories into.
%tensorboard --logdir=./training_logs --port 6008

In [10]:
from datetime import datetime

# Device-placement logging reports the placement decision for every op and
# floods the cell output, so keep it off unless debugging. Setting the flag
# explicitly (True or False) also resets it on re-run, because the setting
# persists in the kernel process.
LOG_DEVICE_PLACEMENT = False
tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)

# One timestamped run directory per training run under ./training_logs.
logs = "./training_logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")

# histogram_freq=1 -> write weight histograms every epoch;
# profile_batch="50,100" -> profile batches 50 through 100.
tboard_callback = keras.callbacks.TensorBoard(
    log_dir=logs,
    histogram_freq=1,
    profile_batch="50,100",
)

history = model.fit(
    train_set_final,
    epochs=5,
    validation_data=valid_set_final,
    callbacks=[tboard_callback],
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5

KeyboardInterrupt: 