In [1]:
import tensorflow as tf
import numpy as np
from config import ModelConfig, TrainConfig
import pickle

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from util import import_dataset

address = '../../data/'
file_names = {}
file_names['images'] = 'full_data.npy'
file_names['subs'] = 'full_subredditlabels'
file_names['dict'] = 'full_subredditIndex'
file_names['nsfw'] = 'full_nsfwlabels'
data, dictionary = import_dataset(address, file_names)

In [3]:
print(data.X_train.shape)
print(data.X_val.shape)
print(data.X_test.shape)
print(data.y_train.shape)
print(data.y_val.shape)
print(data.y_test.shape)

(25450, 128, 128, 3)
(3181, 128, 128, 3)
(3182, 128, 128, 3)
(25450,)
(3181,)
(3182,)


## Loading models

#### Restore AlexNet

In [8]:
from alexnet import AlexNetMulticlass

# Reset Graph
tf.reset_default_graph()

# Create model instance
model_config = ModelConfig(eval_batch_size=2000)
model = AlexNetMulticlass(model_config)

# Load Saved Model
sess = tf.Session()
saver = tf.train.Saver()
save_file = "../../saved_params/AlexNet_multitask_classification_postParamSearch"
saver.restore(sess, save_file) 
saved_history = pickle.load(open(save_file + "_modelhist", 'rb'))
model.model_history = saved_history

# Test Model Accuracy
loss_train, acc_sbrd_train, acc_nsfw_train = model.eval(data, sess, split='train')
loss_val, acc_sbrd_val, ac_nsfw_val = model.eval(data, sess, split = 'val')

subreddit train accuracy:52.0%
nsfw train accuracy:95.1%
subreddit val accuracy:42.9%
nsfw val accuracy:94.2%


In [9]:
# Get model predictions
alex_sbrd_logits, alex_nsfw_logits = sess.run(model.prediction, {model.X_placeholder: data.X_test, \
                                                               model.y_sbrd_placeholder: data.y_test, \
                                                               model.y_nsfw_placeholder: data.y_test_2, \
                                                               model.is_training_placeholder:False})
alex_sbrd_pred = np.argmax(alex_sbrd_logits, axis = 1)
alex_nsfw_pred = np.argmax(alex_nsfw_logits, axis = 1)

#### Restore GoogleNet

In [10]:
from googlenet import GoogleNetMulticlass

# Reset Graph
tf.reset_default_graph()

# Create model instance
model_config = ModelConfig(eval_batch_size=3000)
model = GoogleNetMulticlass(model_config)

# Load Saved Model
sess = tf.Session()
saver = tf.train.Saver()
save_file = "../../saved_params/GoogleNet_multitask_classification_4e-4_99"
saver.restore(sess, save_file) 
saved_history = pickle.load(open(save_file + "_modelhist", 'rb'))
model.model_history = saved_history

# Test Model Accuracy
loss_train, acc_sbrd_train, acc_nsfw_train = model.eval(data, sess, split='train')
loss_val, acc_sbrd_val, ac_nsfw_val = model.eval(data, sess, split = 'val')

subreddit train accuracy:93.6%
nsfw train accuracy:98.7%
subreddit val accuracy:64.1%
nsfw val accuracy:96.7%


In [11]:
# Get model predictions
goog_sbrd_logits, goog_nsfw_logits = sess.run(model.prediction, {model.X_placeholder: data.X_test, \
                                                               model.y_sbrd_placeholder: data.y_test, \
                                                               model.y_nsfw_placeholder: data.y_test_2, \
                                                               model.is_training_placeholder:False})
goog_sbrd_pred = np.argmax(goog_sbrd_logits, axis = 1)
goog_nsfw_pred = np.argmax(goog_nsfw_logits, axis = 1)

#### Load ResNet

In [None]:
from resnet import ResNetMulticlass

# Reset Graph
tf.reset_default_graph()

# Create model instance
model_config = ModelConfig(eval_batch_size=3000)
model = ResNetMulticlass(model_config)

# Load Saved Model
sess = tf.Session()
saver = tf.train.Saver()
save_file = "../../saved_params/ResNet_multitask_final"
saver.restore(sess, save_file) 
saved_history = pickle.load(open(save_file + "_modelhist", 'rb'))
#model.model_history = saved_history

# Test Model Accuracy
loss_train, acc_sbrd_train, acc_nsfw_train = model.eval(data, sess, split='train')
loss_val, acc_sbrd_val, ac_nsfw_val = model.eval(data, sess, split = 'val')

In [None]:
# Get model predictions
res_sbrd_logits, res_nsfw_logits = sess.run(model.prediction, {model.X_placeholder: data.X_test, \
                                                               model.y_sbrd_placeholder: data.y_test, \
                                                               model.y_nsfw_placeholder: data.y_test_2, \
                                                               model.is_training_placeholder:False})
res_sbrd_pred = np.argmax(res_sbrd_logits, axis = 1)
res_nsfw_pred = np.argmax(res_nsfw_logits, axis = 1)

### Run this only when ResNet model can't be properly imported

In [13]:
import pickle
res_sbrd_logits = pickle.load(open('../../test_sbrd_logits.dat', 'rb'))
res_nsfw_logits = pickle.load(open('../../test_nsfw_logits.dat', 'rb'))
res_sbrd_pred = pickle.load(open('../../test_sbrd_classes.dat', 'rb'))
res_nsfw_pred = pickle.load(open('../../test_nsfw_classes.dat', 'rb'))

## Ensembling

In [53]:
# predictions is a list of prediction arrays, one for each model
# default_prediction is a prediction array, that is defaulted to when there is no clear majority
# returns ensembled predictions
def majority_vote_ensemble(predictions, default_prediction):
    from scipy import stats
    
    predictions = np.array(predictions)
    num_models = predictions.shape[1]
    mode, counts = stats.mode(predictions)
    counts = counts[0]
    mode = mode[0]
    indices = counts < (num_models / 2.0)  # indices of predictions with less than majority vote
    vote_pred = mode
    vote_pred[indices] = default_prediction[indices]
    return vote_pred

In [54]:
# Used from CS 224N code
def softmax(x):
    shifted = x - np.max(x)
    exponentiated = np.exp(shifted)
    return exponentiated / np.sum(exponentiated)

# predictions is a list of logit arrays, one for each model
# weights is a list of weights for each model, in the same order as predictions
# returns ensembled predictions
def average_ensemble(logits, weights):
    probs = np.array([softmax(x) for x in logits])
    average_probs = np.average(probs, axis=0, weights=weights)
    return np.argmax(average_probs, axis=1)

In [67]:
def accuracy(prediction, actual):
    return np.average(prediction == actual)

sbrd_logits = [alex_sbrd_logits, goog_sbrd_logits, res_sbrd_logits]
nsfw_logits = [alex_nsfw_logits, goog_nsfw_logits, res_nsfw_logits]
sbrd_preds = [alex_sbrd_pred, goog_sbrd_pred, res_sbrd_pred]
nsfw_preds = [alex_nsfw_pred, goog_nsfw_pred, res_nsfw_pred]

weights = [0.1, 0.7, 0.2]
sbrd_majority_preds = majority_vote_ensemble(sbrd_preds, res_sbrd_pred)
nsfw_majority_preds = majority_vote_ensemble(nsfw_preds, res_nsfw_pred)
sbrd_average_preds = average_ensemble(sbrd_logits, weights)
nsfw_average_preds = average_ensemble(nsfw_logits, weights)

sbrd_majority_acc = accuracy(sbrd_majority_preds, data.y_test)
nsfw_majority_acc = accuracy(nsfw_majority_preds, data.y_test_2)
sbrd_average_acc = accuracy(sbrd_average_preds, data.y_test)
nsfw_average_acc = accuracy(nsfw_average_preds, data.y_test_2)

print("Majority vote accuracies:")
print("Subreddit: " + str(sbrd_majority_acc))
print("NSFW: " + str(nsfw_majority_acc))
print("Average accuracies:")
print("Subreddit: " + str(sbrd_average_acc))
print("NSFW: " + str(nsfw_average_acc))

Majority vote accuracies:
Subreddit: 0.494028912634
NSFW: 0.944060339409
Average accuracies:
Subreddit: 0.473601508485
NSFW: 0.949402891263
