In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models


In [10]:
# Carve the TFDS 'train' split into an 80% training part and a 20% held-out part.
# With a list of split specs, tfds.load returns one dataset per spec, in order.
SPLITS = ['train[:80%]', 'train[80%:]']
ds, ds_info = tfds.load(name='bccd', split=SPLITS, with_info=True)
train_ds = ds[0]
test_ds = ds[1]


In [11]:
def extract_label(example):
    """Turn one dataset example into an ``(image, label)`` pair for classification.

    Args:
        example: A mapping with an ``'image'`` entry and an ``'objects'`` entry
            whose ``'label'`` field is indexable.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the example's image
        converted to ``tf.float32`` with ``tf.image.convert_image_dtype`` and
        ``label`` is the label of the *first* annotated object only.
    """
    # Only index 0 is used: any further annotated objects in the image are ignored.
    # NOTE(review): assumes every example has at least one annotated object;
    # indexing [0] would fail on an empty annotation list -- confirm for this dataset.
    first_object_label = example['objects']['label'][0]
    image_float = tf.image.convert_image_dtype(example['image'], tf.float32)
    return image_float, first_object_label


In [12]:
# Input pipelines for training and evaluation.
#
# The shuffle must come BEFORE batching: shuffling an already-batched dataset
# only permutes whole batches, so every epoch would see the exact same groups
# of 32 examples. Shuffling individual examples gives different batches each epoch.
# It is also placed before `map` so the shuffle buffer holds the raw examples
# rather than the float32-converted images.
BATCH_SIZE = 32
# NOTE(review): chosen to be >= the number of training examples so the shuffle
# is a full shuffle -- confirm against the dataset size and lower it if memory is tight.
SHUFFLE_BUFFER = 1000

# NOTE: this cell rebinds `train_ds` / `test_ds`; re-run the loading cell before
# re-running this one, otherwise `extract_label` is applied to already-mapped data.
train_ds = (
    train_ds
    .shuffle(SHUFFLE_BUFFER)
    .map(extract_label)
    .batch(BATCH_SIZE)
)
# Evaluation data is never shuffled so metrics are computed over a stable order.
test_ds = test_ds.map(extract_label).batch(BATCH_SIZE)


In [None]:
# Sanity check: show the first image of one training batch, titled with its label id.
fig, ax = plt.subplots()
for img, label in train_ds.take(1):
    ax.imshow(img[0])
    ax.set_title(int(label[0]))
    ax.axis("off")
plt.show()


In [14]:
# Random image perturbations applied to training batches only.
# NOTE(review): if the images reaching this stage are floats in [0, 1],
# RandomContrast may push some pixel values outside that range -- confirm this is acceptable.
augmentation_steps = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomContrast(0.2),
]
data_augmentation = tf.keras.Sequential(augmentation_steps)


def _augment_batch(images, labels):
    """Run the augmentation stack on `images`; `labels` pass through untouched."""
    # training=True forces the random layers to be active inside Dataset.map.
    return data_augmentation(images, training=True), labels


augmented_train_ds = train_ds.map(_augment_batch)


In [15]:
# Small CNN classifier: resize -> three Conv/MaxPool stages -> dense head with
# a 3-way softmax output.
conv_stack = []
for filters in (32, 64, 128):
    conv_stack.append(layers.Conv2D(filters, 3, activation="relu"))
    conv_stack.append(layers.MaxPooling2D())

model = models.Sequential([
    # Accept images of any height/width; Resizing brings them all to 128x128.
    layers.Input(shape=(None, None, 3)),
    layers.Resizing(128, 128),
    *conv_stack,
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(3, activation="softmax"),
])

# Sparse categorical cross-entropy: labels are passed as integer class ids.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)


In [None]:
# Train the same architecture twice -- without and with augmentation -- and
# compare the accuracy curves.
EPOCHS = 8

# Snapshot the starting weights before any training in this cell.
# `clone_model` builds a copy with freshly initialised weights, so without this
# snapshot the two runs would start from different random initialisations and the
# comparison would mix the effect of augmentation with initialisation noise.
# NOTE: if this cell is re-run without rebuilding `model`, the snapshot holds the
# already-trained weights; re-run the model-definition cell first.
initial_weights = model.get_weights()

# Run 1: baseline, no augmentation.
history_noaug = model.fit(train_ds, epochs=EPOCHS, validation_data=test_ds)

# Run 2: identical architecture and identical starting weights, augmented inputs.
model_aug = tf.keras.models.clone_model(model)
model_aug.set_weights(initial_weights)
# A cloned model is not compiled, so it needs its own compile call.
model_aug.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
history_aug = model_aug.fit(augmented_train_ds, epochs=EPOCHS, validation_data=test_ds)

# Plot train/validation accuracy of both runs on one set of axes.
fig, ax = plt.subplots()
curves = [
    (history_noaug.history["accuracy"], "train_noaug"),
    (history_noaug.history["val_accuracy"], "val_noaug"),
    (history_aug.history["accuracy"], "train_aug"),
    (history_aug.history["val_accuracy"], "val_aug"),
]
for values, curve_label in curves:
    ax.plot(values, label=curve_label)
ax.set_xlabel("epoch")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy with vs. without augmentation")
ax.legend()
plt.show()


In [None]:
# Report the last-epoch accuracy of each run (training and validation).
final_metrics = [
    ("train acc before aug:", history_noaug.history["accuracy"]),
    ("val acc before aug:", history_noaug.history["val_accuracy"]),
    ("train acc after aug:", history_aug.history["accuracy"]),
    ("val acc after aug:", history_aug.history["val_accuracy"]),
]
for caption, per_epoch_values in final_metrics:
    # [-1] selects the value recorded for the final epoch.
    print(caption, per_epoch_values[-1])
