In [46]:
import numpy as np
import os
from model import simple_model
from dataset import get_data_generators
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

In [47]:
# Dataset root; each split lives in its own sub-folder that flow_from_directory reads.
base_dir = '../pdata/'
train_dir, test_dir, validation_dir = (
    os.path.join(base_dir, split_name) for split_name in ('train', 'test', 'validation')
)

# Shared loader: only rescales pixel values into [0, 1] (no augmentation).
datagen = ImageDataGenerator(rescale=1./255)

In [50]:
def extract_features(directory, sample_count, batch_size=20):
    """Run images from ``directory`` through ``conv_base`` and collect its outputs.

    Images are read with the module-level ``datagen``, resized to the input size
    ``conv_base`` was built for, and fed to ``conv_base.predict`` batch by batch
    until exactly ``sample_count`` samples have been collected.

    Args:
        directory: Root folder passed to ``flow_from_directory`` (one sub-folder
            per class).
        sample_count: Number of samples to extract; must be positive.
        batch_size: Images per generator batch. The earlier version read an
            undefined global ``batch_size``; 20 is an assumed default.

    Returns:
        ``(features, labels)``: ``features`` has shape
        ``(sample_count,) + conv_base.output_shape[1:]``; ``labels`` is one-hot
        encoded with one column per class sub-folder found in ``directory``.

    Raises:
        ValueError: if ``sample_count`` is not positive.
    """
    if sample_count <= 0:
        raise ValueError("sample_count must be positive, got {}".format(sample_count))

    features = np.zeros(shape=(sample_count,) + tuple(conv_base.output_shape[1:]))
    labels = np.zeros(shape=(sample_count,))
    generator = datagen.flow_from_directory(
        directory,
        # Keep the loader in sync with the input shape conv_base was built with.
        target_size=tuple(conv_base.input_shape[1:3]),
        batch_size=batch_size,
        # 'sparse' yields integer class indices; 'binary' is documented for two
        # classes only, while this pipeline is multi-class.
        class_mode='sparse')

    filled = 0
    for inputs_batch, labels_batch in generator:
        # A batch may be short (end of an epoch) or only partly needed (last
        # batch): write exactly the rows that remain, contiguously.
        take = min(len(inputs_batch), sample_count - filled)
        features[filled:filled + take] = conv_base.predict(inputs_batch[:take])
        labels[filled:filled + take] = labels_batch[:take]
        filled += take
        if filled >= sample_count:
            break

    # Pin the one-hot width to the number of class folders so every split has the
    # same label shape even if the sampled subset misses the highest class index.
    return features, to_categorical(labels, num_classes=len(generator.class_indices))

In [4]:
# Conv base: VGG16 without its top (dense classifier) layers. It is only used via
# conv_base.predict() in extract_features, i.e. as a feature extractor for
# 300x300 3-channel images.
from tensorflow.keras.applications import VGG16

conv_base = VGG16(
    input_shape=(300, 300, 3),
    include_top=False,
)
conv_base.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 300, 300, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 300, 300, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 300, 300, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 150, 150, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 150, 150, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 150, 150, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 75, 75, 128)       0         
__________

In [12]:
def extract_VGG16_features(train_count=268, validation_count=163, test_count=106):
    """Extract flattened conv-base features for the train/validation/test splits.

    Each split is passed through ``extract_features`` and its feature maps are
    flattened to one vector per sample, ready for a dense classifier.

    Args:
        train_count: Number of samples to extract from ``train_dir``.
        validation_count: Number of samples to extract from ``validation_dir``.
        test_count: Number of samples to extract from ``test_dir``.

    Returns:
        ``[train_features, train_labels, validation_features, validation_labels,
        test_features, test_labels]`` where each features array is 2-D
        ``(count, flattened_feature_size)`` and labels are as returned by
        ``extract_features``.
    """
    splits = (
        (train_dir, train_count),
        (validation_dir, validation_count),
        (test_dir, test_count),
    )
    extracted_data = []
    for directory, count in splits:
        features, labels = extract_features(directory, count)
        # Flatten each feature map; -1 avoids hard-coding the conv output size.
        extracted_data.extend([features.reshape(len(features), -1), labels])
    return extracted_data

In [40]:
def transfer_model(input_dim=9 * 9 * 512, num_classes=5):
    """Build the dense classifier trained on top of the flattened VGG16 features.

    Args:
        input_dim: Length of each flattened feature vector. The default,
            9 * 9 * 512 = 41472, matches the (9, 9, 512) feature maps used for
            300x300x3 inputs in this notebook.
        num_classes: Number of softmax output units (one per class).

    Returns:
        An uncompiled ``models.Sequential`` model.
    """
    model = models.Sequential()
    model.add(layers.Dense(64, activation='relu', input_shape=(input_dim,)))
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(128, activation='relu'))
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(num_classes, activation='softmax'))
    return model

In [71]:

NUM_EPOCHS = 3
# Batch size used by model.fit in main(). main() previously hard-coded 3 and
# ignored this constant; the value is kept at 3 to reproduce that run.
BATCH_SIZE = 3
# NOTE(review): not read by main(); presumably leftovers from a generator-based
# training loop. Remove if no other cell uses them.
STEPS_PER_EPOCH = 2
VALIDATION_STEPS = 2


def main():
    """Extract VGG16 features, train the dense head, and report test metrics.

    Prints the test loss and accuracy.

    Returns:
        The History object returned by ``model.fit``.
    """
    model = transfer_model()
    (train_features, train_labels,
     validation_features, validation_labels,
     test_features, test_labels) = extract_VGG16_features()

    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
    history = model.fit(
        train_features, train_labels,
        batch_size=BATCH_SIZE,
        epochs=NUM_EPOCHS,
        validation_data=(validation_features, validation_labels),
    )

    test_loss, test_acc = model.evaluate(test_features, test_labels)
    print("Test Loss = {}, Test Accuracy = {}".format(test_loss, test_acc))
    return history

In [72]:
# Run the whole pipeline: feature extraction, training, and test evaluation.
main();

Train on 268 samples, validate on 163 samples
Epoch 1/3
Epoch 2/3
Epoch 3/3
Test Loss = 1.0371726751327515, Test Accuracy = 0.49999999887538404
