In [1]:
import tensorflow as tf
from tensorflow.keras import losses, optimizers, layers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications import MobileNet


import numpy as np
import cv2
import matplotlib.pyplot as plt
import os


# Dataset location; each split directory holds one sub-directory per class.
BASE_DIR = 'data/100-bird-species/'
TRAIN_DIR = os.path.join(BASE_DIR, 'train')
VALIDATION_DIR = os.path.join(BASE_DIR, 'valid')  # NOTE(review): not read by any cell below
TEST_DIR = os.path.join(BASE_DIR, 'test')

# Class names. os.listdir returns entries in arbitrary order, so sort them to
# make the label index <-> class name mapping identical on every run/machine
# (otherwise a saved model's output units cannot be mapped back to names).
# Only sub-directories count, so stray files (e.g. .DS_Store) are not classes.
CATEGORIES = sorted(
    name for name in os.listdir(TRAIN_DIR)
    if os.path.isdir(os.path.join(TRAIN_DIR, name))
)  # 175 classes in the dataset version this notebook was written against

# Seed NumPy and TensorFlow so weight initialisation and batch shuffling are
# repeatable across Restart & Run All (GPU kernels may still add nondeterminism).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [None]:
def _load_split(split_dir, image_size=(224, 224)):
    """Load every image under ``split_dir/<category>/`` as an RGB array.

    Args:
        split_dir: a dataset split directory (e.g. TRAIN_DIR, TEST_DIR) that
            contains one sub-directory per class name in CATEGORIES.
        image_size: (width, height) that images are resized to when they do
            not already match; 224x224 matches the MobileNet input_shape used
            when the model is built.

    Returns:
        (images, labels): a list of HxWx3 uint8 RGB arrays and a parallel list
        of integer class indices into CATEGORIES.

    Raises:
        KeyError: a class folder in ``split_dir`` is not listed in CATEGORIES.
        IOError: a file cannot be decoded as an image by OpenCV.
    """
    # Dict lookup instead of CATEGORIES.index(), which is O(n) per image.
    category_index = {name: idx for idx, name in enumerate(CATEGORIES)}
    images, labels = [], []
    # Sorted listings keep the sample order deterministic across runs.
    for category in sorted(os.listdir(split_dir)):
        category_dir = os.path.join(split_dir, category)
        if not os.path.isdir(category_dir):
            continue  # ignore stray files next to the class folders
        if category not in category_index:
            raise KeyError(f'Class folder {category!r} in {split_dir} is not in CATEGORIES')
        label = category_index[category]
        for file_name in sorted(os.listdir(category_dir)):
            image_path = os.path.join(category_dir, file_name)
            img = cv2.imread(image_path)
            # cv2.imread signals failure by returning None rather than raising.
            if img is None:
                raise IOError(f'Could not read image: {image_path}')
            # img.shape is (height, width, channels); cv2.resize takes (width, height).
            if img.shape[:2] != (image_size[1], image_size[0]):
                img = cv2.resize(img, image_size)
            images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR
            labels.append(label)
    return images, labels


# NOTE(review): every image is held in memory; for large splits a streaming
# tf.data pipeline may be needed instead.
train_data, train_labels = _load_split(TRAIN_DIR)
test_data, test_labels = _load_split(TEST_DIR)


In [None]:
# Model inputs. The pretrained MobileNet base expects pixels scaled to [-1, 1]
# by mobilenet.preprocess_input (not /255 to [0, 1]). New names
# (train_images/test_images, which the training and evaluation cells read)
# leave the raw uint8 train_data/test_data untouched, so this cell can be
# re-run without scaling twice. np.array(...) always copies, so the in-place
# scaling done by preprocess_input never touches the source data.
train_images = tf.keras.applications.mobilenet.preprocess_input(
    np.array(train_data, dtype='float32'))
test_images = tf.keras.applications.mobilenet.preprocess_input(
    np.array(test_data, dtype='float32'))


def _one_hot(labels):
    """One-hot encode integer class indices into len(CATEGORIES) columns.

    Passing num_classes keeps train/test widths equal to the model's output
    size even if a split happens to lack the highest class index. Already
    encoded (2-D) input is returned unchanged so re-running this cell is safe.
    """
    labels = np.asarray(labels)
    if labels.ndim == 2:
        return labels
    return to_categorical(labels, num_classes=len(CATEGORIES))


train_labels = _one_hot(train_labels)
test_labels = _one_hot(test_labels)

In [None]:
# Transfer learning: ImageNet-pretrained MobileNet (without its classifier
# top) used as a frozen feature extractor, plus a small trainable head.
conv_base = MobileNet(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
conv_base.trainable = False  # only the Dense layers below are trained

model = tf.keras.models.Sequential()
model.add(conv_base)
# NOTE(review): Flatten over the base's feature map makes the first Dense
# layer large; GlobalAveragePooling2D is a common lighter alternative.
model.add(layers.Flatten())
model.add(layers.Dense(256, activation=tf.nn.relu))
model.add(layers.Dense(len(CATEGORIES), activation=tf.nn.softmax))  # one unit per class

model.compile(
    # `learning_rate` replaces the deprecated `lr` alias, which newer Keras
    # optimizers reject with an error.
    optimizer=optimizers.RMSprop(learning_rate=0.001),
    loss=losses.categorical_crossentropy,  # labels are one-hot encoded
    metrics=['accuracy'],
)

In [None]:
# Train the classification head on top of the frozen MobileNet base. The
# returned History object is plotted by the evaluation cell.
fit_options = {'batch_size': 32, 'epochs': 40}
history = model.fit(train_images, train_labels, **fit_options)

In [None]:
# Persist the trained model to an HDF5 (.h5) file in the working directory.
model.save(filepath='100-bird-species.h5')

In [None]:
# Restore weights from the file written by the save cell, so a fresh kernel
# can skip the 40-epoch fit (the model-building cell must still run first).
model.load_weights(filepath='100-bird-species.h5')

In [None]:
# Score the model on the held-out test split.
_, accuracy = model.evaluate(test_images, test_labels)
print('Accuracy: ', round(accuracy * 100, 2), '%')

# Training curve. `history` exists only if the training cell ran in this
# kernel session; restoring weights from disk does not recreate it.
fig, ax = plt.subplots()
# A labelled line is required for ax.legend(); without one Matplotlib warns
# "No artists with labels found" and draws an empty legend.
ax.plot(history.history['accuracy'], 'r-', label='train accuracy')
ax.set(title='Training accuracy per epoch', xlabel='Epoch', ylabel='Accuracy')
ax.legend()
ax.grid(True)
plt.show()