In [1]:
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import NASNetLarge

In [2]:
# Load data
def _fer_generator(directory, shuffle):
    """Build a directory iterator over FER2013 images for NASNetLarge.

    Images are read as RGB, resized to 331x331, passed through the NASNet
    `preprocess_input` function and yielded in batches of 256 together with
    categorical labels derived from the sub-directory names.

    Args:
        directory: Path to a folder containing one sub-folder per class.
        shuffle: Whether the iterator should shuffle the sample order.

    Returns:
        The iterator returned by `ImageDataGenerator.flow_from_directory`.
    """
    return tf.keras.preprocessing.image.ImageDataGenerator(
        preprocessing_function=tf.keras.applications.nasnet.preprocess_input
    ).flow_from_directory(
        directory,
        class_mode="categorical",
        color_mode="rgb",
        batch_size=256,
        target_size=(331, 331),
        shuffle=shuffle,
    )


train = _fer_generator("FER2013/train/", shuffle=True)
# Keep the test set in a fixed order so that per-sample predictions line up
# with `test.classes` / `test.filenames`; aggregate metrics are unaffected.
test = _fer_generator("FER2013/test/", shuffle=False)

Found 28709 images belonging to 7 classes.
Found 7178 images belonging to 7 classes.


In [3]:
# Network input: 331x331 spatial size (same as the generators' target_size)
# followed by the 3 RGB channels.
input_shape = (331, 331) + (3,)

In [4]:
# NASNetLarge backbone with ImageNet weights, no classification head, and
# global average pooling applied to its final feature maps.
base = NASNetLarge(include_top=False, weights="imagenet", pooling="avg", input_shape=input_shape)

# Freeze the backbone so that only layers stacked on top of it are trained.
base.trainable = False

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/nasnet/NASNet-large-no-top.h5


In [5]:
# Classifier: the NASNet backbone followed by a single 7-way softmax layer.
# The backbone is created with pooling='avg', so it already emits a flat
# feature vector and no Flatten layer is required before the Dense head.
model = keras.Sequential()
model.add(base)
model.add(layers.Dense(7, activation="softmax"))

In [6]:
# Print the layer-by-layer architecture and parameter counts of the assembled model.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
NASNet (Functional)          (None, 4032)              84916818  
_________________________________________________________________
dense (Dense)                (None, 7)                 28231     
Total params: 84,945,049
Trainable params: 28,231
Non-trainable params: 84,916,818
_________________________________________________________________


In [7]:
# Training configuration: Adam optimizer, categorical cross-entropy loss
# (the generators use class_mode="categorical"), accuracy as the reported metric.
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

In [8]:
# Checkpoint callback: at the end of each epoch, write the model to "weights"
# only if the monitored training accuracy improved on the best value so far.
#
# `monitor` only takes effect together with save_best_only=True; without it
# the checkpoint was overwritten on every epoch regardless of accuracy, so the
# file on disk was simply the most recent epoch. mode="max" states explicitly
# that a higher accuracy is better.
# NOTE(review): training accuracy is monitored because fit() is called without
# validation data; switch to "val_accuracy" if a validation split is added.
check = tf.keras.callbacks.ModelCheckpoint(
    "weights",
    monitor="accuracy",
    mode="max",
    save_best_only=True,
    verbose=0,
    save_freq="epoch",
)

In [9]:
# Train on the training generator for up to 10 epochs, with the checkpoint
# callback attached; data loading stays in the main process.
model.fit(train, epochs=10, callbacks=[check], use_multiprocessing=False)

Epoch 1/10
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: weights/assets
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10

KeyboardInterrupt: 

In [None]:
# Compute the compiled loss and accuracy over the test generator.
model.evaluate(test, use_multiprocessing=False)