In [1]:
import pandas as pd
import numpy as np
import random
import pickle
from os import path

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications import ResNet50
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping

In [2]:
# preprocess dataframe
poster_dir = "../data/posters/"
poster_df = pd.read_csv("../data/MovieGenre.csv", encoding = "ISO-8859-1")
poster_df = poster_df.drop_duplicates(subset=['imdbId'], keep="last")
poster_df["Genre"] = poster_df["Genre"].str.split("|")
poster_df["filename"] = poster_df["imdbId"].astype(str) + ".jpg"
poster_df = poster_df[poster_df["filename"].apply(lambda t: path.exists(poster_dir + t))]
poster_df = poster_df.dropna()
poster_df = poster_df[poster_df["Genre"].apply(lambda t: isinstance(t, list))]
#poster_df = poster_df.iloc[random.sample(range(0, 30000), 10000)]

In [3]:
# create image data generator
datagen = ImageDataGenerator(rescale=1./255., validation_split=0.25)

train_generator = datagen.flow_from_dataframe(dataframe=poster_df,
                                              directory=poster_dir,
                                              x_col="filename",
                                              y_col="Genre",
                                              subset="training",
                                              batch_size=32,
                                              shuffle=True,
                                              seed=42,
                                              class_mode="categorical",
                                              target_size=(64, 64))


valid_generator=datagen.flow_from_dataframe(dataframe=poster_df,
                                            directory=poster_dir,
                                            x_col="filename",
                                            y_col="Genre",
                                            subset="validation",
                                            batch_size=32,
                                            shuffle=True,
                                            seed=42,
                                            class_mode="categorical",
                                            target_size=(64, 64))

num_classes = len(train_generator.class_indices)
print(train_generator.class_indices)

Found 27321 validated image filenames belonging to 28 classes.
Found 9107 validated image filenames belonging to 28 classes.
{'Action': 0, 'Adult': 1, 'Adventure': 2, 'Animation': 3, 'Biography': 4, 'Comedy': 5, 'Crime': 6, 'Documentary': 7, 'Drama': 8, 'Family': 9, 'Fantasy': 10, 'Film-Noir': 11, 'Game-Show': 12, 'History': 13, 'Horror': 14, 'Music': 15, 'Musical': 16, 'Mystery': 17, 'News': 18, 'Reality-TV': 19, 'Romance': 20, 'Sci-Fi': 21, 'Short': 22, 'Sport': 23, 'Talk-Show': 24, 'Thriller': 25, 'War': 26, 'Western': 27}


In [4]:
# use ResNet50 model for classification
base_model = ResNet50(include_top=False, weights=None, input_shape=(64, 64, 3))

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(128, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)

model = Model(base_model.input, predictions)

adam = Adam(learning_rate=1e-2, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
callback = EarlyStopping(patience=3)
model.compile(optimizer=adam, loss='binary_crossentropy',
              metrics=['accuracy'])

STEP_SIZE_TRAIN=train_generator.n//train_generator.batch_size
STEP_SIZE_VALID=valid_generator.n//valid_generator.batch_size

history = model.fit(x=train_generator,
                    steps_per_epoch=STEP_SIZE_TRAIN,
                    epochs=25,
                    validation_data=valid_generator,
                    validation_steps=STEP_SIZE_VALID,
                    callbacks=[callback])

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25


In [5]:
# save the model for deployment
pickle.dump(model, open('poster_predictor_keras.pkl', 'wb+'))

TypeError: cannot pickle 'weakref' object

Reference:
1. https://godatadriven.com/blog/keras-multi-label-classification-with-imagedatagenerator/