In [1]:
%load_ext autoreload
%autoreload 2
from lib.feature_extractor import NASNetLargeExtractor

In [5]:
# Build the NASNet-Large feature extractor; per the original note this downloads
# Google's pre-trained NASNet-Large weights.
# NOTE(review): the positional args (32, 10) presumably mean CIFAR-10's 32x32 input
# size and its 10 classes -- confirm against lib.feature_extractor.NASNetLargeExtractor.
# The same literals are repeated when the model is re-created further down.
model = NASNetLargeExtractor(32, 10)

In [16]:
# Load the CIFAR-10 train/test splits and convert the image arrays to float32 so they
# can be normalized in the next step.
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

x_train, x_test = x_train.astype('float32'), x_test.astype('float32')

def preprocess_data(data_set):
    """Scale raw pixel intensities by 1/255.

    The division is done out of place, so the caller's array is never modified
    (the previous in-place ``/=`` silently rewrote whatever array was passed in).
    Because it is out of place it also accepts integer arrays such as uint8 images,
    which in-place division by a float rejects; float32 input stays float32.

    Args:
        data_set: numpy array of pixel values.

    Returns:
        A new array equal to ``data_set / 255.0``.
    """
    return data_set / 255.0

# Normalize both image splits and one-hot encode both label splits over the 10 classes.
x_train, x_test = (preprocess_data(images) for images in (x_train, x_test))
y_train, y_test = (to_categorical(labels, num_classes=10) for labels in (y_train, y_test))

# Carve a 20% validation split out of the training data (fixed seed for reproducibility).
x_train, x_valid, y_train, y_valid = train_test_split(
    x_train,
    y_train,
    test_size=0.2,
    random_state=42,
)

In [24]:
# Extract features for only the first 50 training images (a quick smoke test, not the
# full split). The recorded output shows one 4032-wide feature row per image.
# NOTE(review): later cells call model.save_features() / model.train_classifier() without
# passing these features, and model.features also reports (50, 4032), so the extractor
# apparently keeps the last result internally -- confirm in lib.feature_extractor.
features_train = model.extract(x_train[:50])
print(features_train.shape)

(50, 4032)


In [18]:
# Persist the extractor's stored features (presumably those from the last extract() call)
# under dataset/cifar10/.
# NOTE(review): this cell's execution count (18) is lower than the extraction cell's (24),
# so the saved data reflects an earlier kernel state, not necessarily the 50-row matrix
# above. Restart & Run All before trusting the saved file.
model.save_features("dataset/cifar10/")

Extracted training set features saved


In [19]:
# Reload previously saved features from dataset/cifar10/.
# FIXME: as recorded in the output, this call currently fails with a (pandas-style)
# ParserError "Calling read(nbytes) on source failed", so the loader apparently could not
# read what save_features wrote -- investigate save_features/load_features in
# lib.feature_extractor. The notebook should not be saved with a failing cell.
model.load_features("dataset/cifar10/")

ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [25]:
# Train a small dense classifier on the stored features to gauge their quality.
# train_classifier receives only labels, so it presumably pairs them with model.features
# (the 50-row matrix printed here). Guard that the label slice lines up row-for-row:
# out-of-order execution (a re-run extraction, a features reload) can silently change
# model.features and misalign it with the hard-coded label slice.
# Per the recorded log, the method extracts features for the raw validation images itself.
train_labels = y_train[:50, :]
print(model.features.shape)
assert model.features.shape[0] == len(train_labels), (
    f"feature/label row mismatch: {model.features.shape[0]} feature rows "
    f"vs {len(train_labels)} labels"
)
history = model.train_classifier(
    train_labels,
    epochs=2,
    batch_size=50,
    validation_data=(x_valid[:18], y_valid[:18, :]),
)

(50, 4032)
Extracting features for validation data
Epoch 1/2
Epoch 2/2
Restoring best model weights with validation accuracy: 0.3333333432674408


In [26]:
# Save the trained classifier and the extractor under model/cifar10/.
# NOTE(review): these calls pass a directory, while later cells pass single .h5 file
# names ("extractor.h5", "classification.h5") to the same methods -- confirm which
# form lib.feature_extractor expects.
model.save_classifier("model/cifar10/")
model.save_extractor("model/cifar10/")

In [27]:
# Round-trip check: replace `model` with a freshly constructed extractor, then load the
# classifier and extractor that the previous cell saved to model/cifar10/.
# NOTE(review): (32, 10) duplicates the constructor arguments used when the model was
# first built; consider hoisting them into named config constants.
model = NASNetLargeExtractor(32, 10)
model.load_classifier("model/cifar10/")
model.load_extractor("model/cifar10/")

In [23]:
# Fine-tune the network on the first 10 training samples (batch size 5, 2 epochs),
# validating on the first 18 validation samples. The call returns a pair, unpacked
# into `history` and `features`.
# NOTE(review): execution count 23 is lower than the reload cell's (27), so the recorded
# output predates the model reload above.
history, features = model.extract_fine_tuned_features(
    x_train[:10],
    y_train[:10, :],
    batch_size=5,
    epochs=2,
    validation_data=(x_valid[:18], y_valid[:18, :]),
)

Epoch 1/2
Epoch 2/2
Restoring best model weights with validation accuracy: 0.3981481542189916


In [20]:
# Save the extractor to the single file extractor.h5 (relative to the working directory).
# NOTE(review): an earlier cell passes the directory "model/cifar10/" to save_extractor
# instead of a file name -- confirm which form the method expects. Execution count 20
# means this cell ran before the training/fine-tuning cells above it.
model.save_extractor("extractor.h5")

In [21]:
# Load the extractor back from extractor.h5.
model.load_extractor("extractor.h5")

In [22]:
# Save the classification layer to the single file classification.h5.
model.save_classifier("classification.h5")

In [23]:
# Load the classification layer back from classification.h5.
model.load_classifier("classification.h5")