Author: Tyler Chase

Date: 2017/05/18

# Model NSFW Classification

This code uses an AlexNet model to classify an image as not safe for work (nsfw) or safe for work (sfw). 

## Load Datasets

In [None]:
import tensorflow as tf
import numpy as np
import math
import timeit
import random
import pickle
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import confusion_matrix
from util import import_dataset
from model import Model, lazy_property
from config import ModelConfig, TrainConfig
%matplotlib inline

# Set default to auto import packages
%load_ext autoreload
%autoreload 2

In [None]:
# Form training, development, and testing data sets.
# The data directory is given relative to this notebook so it works on any machine.
address = r'../../data/fullData//'
file_names = {
    'images': 'full_data.npy',
    'subs': 'full_subredditlabels',
    'dict': 'full_subredditIndex',
    'nsfw': 'full_nsfwlabels',
}
data, dictionary = import_dataset(address, file_names)

# Print the sizes as a sanity check (subreddit labels = y, nsfw labels = y_*_2).
splits = [
    ('Train', data.X_train, data.y_train, data.y_train_2),
    ('Validation', data.X_val, data.y_val, data.y_val_2),
    ('Test', data.X_test, data.y_test, data.y_test_2),
]
for split_name, X, y_sub, y_nsfw in splits:
    print(split_name + ' data shape: ', X.shape)
    print(split_name + ' subreddit labels shape: ', y_sub.shape)
    print(split_name + ' nsfw labels shape: ', y_nsfw.shape)
    # Every image in a split must have exactly one subreddit and one nsfw label.
    assert X.shape[0] == y_sub.shape[0] == y_nsfw.shape[0], split_name + ' split sizes disagree'

## Determine Subreddit Statistics of Training Set

In [3]:
# Invert the name -> index dictionary so classes[k] is the name of subreddit k.
num_subs = len(dictionary)
classes = [""] * num_subs
for name, idx in dictionary.items():
    classes[idx] = name

# Count training posts per subreddit label and print each count.
stats = [0] * num_subs
for label in range(num_subs):
    count = np.sum(label == data.y_train)
    stats[label] = count
    print(classes[label] + ' Submissions: ', count)
print('Sanity Check Sum: ', np.sum(stats))

# The per-subreddit counts above should add up to this.
total = np.shape(data.y_train)[0]
print('\nTotal Submissions: ', total)

## Determine NSFW Statistics in Training Dataset

In [4]:
# Binary label encoding used in data.y_train_2: NSFW -> 1, SFW -> 0.
dict_nsfw = {'NSFW': 1, 'SFW': 0}

# Store the label names by index and count training posts per label.
num_out = len(dict_nsfw)
classes_nsfw = [""] * num_out
stats_nsfw = [0] * num_out
for label_name, label_val in dict_nsfw.items():
    count = np.sum(label_val == data.y_train_2)
    classes_nsfw[label_val] = label_name
    stats_nsfw[label_val] = count
    print(label_name + ' Submissions: ', count)
print('Sanity Check Sum: ', np.sum(stats_nsfw))

total_nsfw = np.shape(data.y_train_2)[0]
print('\nTotal Submissions: ', total_nsfw)

NSFW Submissions:  2164
SFW Submissions:  23286
Sanity Check Sum:  25450

Total Submissions:  25450


## Determine NSFW Images Per Subreddit

In [5]:
# For every subreddit, split its training posts into NSFW (label 1) and
# SFW (label 0) counts, store them, and print them.
nsfw_breakdown = {}
for sub_idx, sub_name in enumerate(classes):
    sub_nsfw_labels = data.y_train_2[data.y_train == sub_idx]
    counts = {'nsfw': np.sum(sub_nsfw_labels == 1),
              'sfw': np.sum(sub_nsfw_labels == 0)}
    nsfw_breakdown[sub_name] = counts
    print(sub_name, ': ', counts['nsfw'] + counts['sfw'])
    print('NSFW: ', counts['nsfw'])
    print('SFW: ', counts['sfw'])
    print()


EarthPorn :  1362
NSFW:  0
SFW:  1362

SkyPorn :  1359
NSFW:  0
SFW:  1359

spaceporn :  1307
NSFW:  0
SFW:  1307

MilitaryPorn :  1331
NSFW:  4
SFW:  1327

GunPorn :  1324
NSFW:  1
SFW:  1323

carporn :  1320
NSFW:  1
SFW:  1319

CityPorn :  1336
NSFW:  0
SFW:  1336

ruralporn :  969
NSFW:  0
SFW:  969

ArchitecturePorn :  1278
NSFW:  0
SFW:  1278

FoodPorn :  1364
NSFW:  0
SFW:  1364

MoviePosterPorn :  1354
NSFW:  10
SFW:  1344

ArtPorn :  1349
NSFW:  110
SFW:  1239

RoomPorn :  1357
NSFW:  0
SFW:  1357

creepy :  1306
NSFW:  98
SFW:  1208

gonewild :  982
NSFW:  982
SFW:  0

PrettyGirls :  1329
NSFW:  0
SFW:  1329

ladybonersgw :  917
NSFW:  917
SFW:  0

LadyBoners :  1191
NSFW:  39
SFW:  1152

cats :  1356
NSFW:  1
SFW:  1355

dogpictures :  1359
NSFW:  1
SFW:  1358



## Balancing the SFW/NSFW Data Content

Safe-for-work (SFW) content makes up roughly 90% of the training data (23,286 of 25,450 posts). To balance the classes, we keep only four subreddits that contain pictures of people. r/gonewild and r/ladybonersgw contain mostly NSFW content, of women and men respectively. r/PrettyGirls and r/LadyBoners contain mostly SFW content, of women and men respectively. Only the training split is filtered this way; the validation and test splits keep all subreddits.

In [6]:
# Keep only the four people-centric subreddits so the training set has a
# roughly even NSFW/SFW split.
subreddits_of_interest = ['gonewild', 'ladybonersgw', 'PrettyGirls', 'LadyBoners']

# Build one boolean row mask per subreddit (in list order) and concatenate the
# selected rows once at the end, rather than re-concatenating growing arrays
# on every iteration.
masks = []
total = 0
for sub_name in subreddits_of_interest:
    mask = dictionary[sub_name] == data.y_train
    found = np.sum(mask)
    print(sub_name)
    print('posts found: ', found)
    print()
    total += found
    masks.append(mask)

data_subset = np.concatenate([data.X_train[mask] for mask in masks], axis=0)
out_subset = np.concatenate([data.y_train[mask] for mask in masks], axis=0)
out_subset_2 = np.concatenate([data.y_train_2[mask] for mask in masks], axis=0)

print('sanity check')
print('posts found: ', total)
print('length training: ', np.shape(data_subset)[0])
assert total == np.shape(data_subset)[0], 'subset size does not match the per-subreddit counts'

# Permute the training data for training.
# random.shuffle with this SEED is kept so the permutation is unchanged.
# NOTE: this overwrites data.X_train / data.y_train / data.y_train_2 in place.
# Re-running the cell without reloading the data re-filters the already
# shuffled subset, so the final row order will differ from the first run.
# Only the training split is filtered; validation and test keep all subreddits.
SEED = 455
random.seed(SEED)
N = np.shape(out_subset)[0]
indices = np.arange(N)
random.shuffle(indices)
data.X_train = data_subset[indices]
data.y_train = out_subset[indices]
data.y_train_2 = out_subset_2[indices]

gonewild
posts found:  982

ladybonersgw
posts found:  917

PrettyGirls
posts found:  1329

LadyBoners
posts found:  1191

sanity check
posts found:  4419
length training:  4419


## Check the Subreddit Statistics

In [7]:
# Re-check per-subreddit counts after balancing. Subreddits other than the
# four kept above are expected to report 0 training posts.
num_subs = len(dictionary)
classes = [""] * num_subs
for sub_name, sub_idx in dictionary.items():
    classes[sub_idx] = sub_name

stats = [np.sum(k == data.y_train) for k in range(num_subs)]
for sub_name, count in zip(classes, stats):
    print(sub_name + ' Submissions: ', count)
print('Sanity Check Sum: ', np.sum(stats))

# Size of the (now balanced) training set.
total = np.shape(data.y_train)[0]
print('\nTotal Submissions: ', total)

EarthPorn Submissions:  0
SkyPorn Submissions:  0
spaceporn Submissions:  0
MilitaryPorn Submissions:  0
GunPorn Submissions:  0
carporn Submissions:  0
CityPorn Submissions:  0
ruralporn Submissions:  0
ArchitecturePorn Submissions:  0
FoodPorn Submissions:  0
MoviePosterPorn Submissions:  0
ArtPorn Submissions:  0
RoomPorn Submissions:  0
creepy Submissions:  0
gonewild Submissions:  982
PrettyGirls Submissions:  1329
ladybonersgw Submissions:  917
LadyBoners Submissions:  1191
cats Submissions:  0
dogpictures Submissions:  0
Sanity Check Sum:  4419

Total Submissions:  4419


## Check the NSFW Statistics

In [8]:
# NSFW/SFW totals of the balanced training set (NSFW -> 1, SFW -> 0).
dict_nsfw = {'NSFW': 1, 'SFW': 0}
num_out = len(dict_nsfw)
classes_nsfw = [""] * num_out
stats_nsfw = [0] * num_out
for label_name in dict_nsfw:
    label_val = dict_nsfw[label_name]
    classes_nsfw[label_val] = label_name
    stats_nsfw[label_val] = np.sum(label_val == data.y_train_2)
    print(label_name + ' Submissions: ', stats_nsfw[label_val])
print('Sanity Check Sum: ', np.sum(stats_nsfw))

total_nsfw = np.shape(data.y_train_2)[0]
print('\nTotal Submissions: ', total_nsfw)

NSFW Submissions:  1938
SFW Submissions:  2481
Sanity Check Sum:  4419

Total Submissions:  4419


## Check NSFW Images Per Subreddit (Balanced Training Set)

In [9]:
# Per-subreddit NSFW/SFW split of the balanced training set.
nsfw_breakdown = {}
for sub_idx, sub_name in enumerate(classes):
    rows_in_sub = np.flatnonzero(data.y_train == sub_idx)
    sub_labels = data.y_train_2[rows_in_sub]
    n_nsfw = np.sum(sub_labels == 1)
    n_sfw = np.sum(sub_labels == 0)
    nsfw_breakdown[sub_name] = {'nsfw': n_nsfw, 'sfw': n_sfw}
    print(sub_name, ': ', n_nsfw + n_sfw)
    print('NSFW: ', n_nsfw)
    print('SFW: ', n_sfw)
    print()


EarthPorn :  0
NSFW:  0
SFW:  0

SkyPorn :  0
NSFW:  0
SFW:  0

spaceporn :  0
NSFW:  0
SFW:  0

MilitaryPorn :  0
NSFW:  0
SFW:  0

GunPorn :  0
NSFW:  0
SFW:  0

carporn :  0
NSFW:  0
SFW:  0

CityPorn :  0
NSFW:  0
SFW:  0

ruralporn :  0
NSFW:  0
SFW:  0

ArchitecturePorn :  0
NSFW:  0
SFW:  0

FoodPorn :  0
NSFW:  0
SFW:  0

MoviePosterPorn :  0
NSFW:  0
SFW:  0

ArtPorn :  0
NSFW:  0
SFW:  0

RoomPorn :  0
NSFW:  0
SFW:  0

creepy :  0
NSFW:  0
SFW:  0

gonewild :  982
NSFW:  982
SFW:  0

PrettyGirls :  1329
NSFW:  0
SFW:  1329

ladybonersgw :  917
NSFW:  917
SFW:  0

LadyBoners :  1191
NSFW:  39
SFW:  1152

cats :  0
NSFW:  0
SFW:  0

dogpictures :  0
NSFW:  0
SFW:  0



## Define AlexNet model 

* 11x11 convolutional layer with 96 filters and a stride of 4
* ReLU activation
* 3x3 max pooling with a stride of 2
* batch normalization


* 5x5 convolutional layer with 256 filters and a stride of 1
* ReLU activation
* 3x3 max pooling with a stride of 2
* batch normalization


* 3x3 convolutional layer with 384 filters and a stride of 1
* ReLU activation
* 3x3 convolutional layer with 384 filters and a stride of 1
* ReLU activation 
* 3x3 convolutional layer with 256 filters and a stride of 1
* ReLU activation
* 3x3 max pooling with a stride of 2


* affine layer from 4096 to 4096
* ReLU activation
* affine layer from 4096 to 4096
* ReLU activation
* affine layer from 4096 to 2

In [None]:
class AlexNet(Model):
    """AlexNet-style convolutional network for binary NSFW/SFW classification.

    Only the forward pass is defined here. The placeholders it reads
    (``X_placeholder``, ``is_training_placeholder``), ``self.config``, and the
    ``train`` / ``eval`` / ``plot_loss_acc`` methods called on instances are not
    defined in this class; they presumably come from the ``Model`` base class.
    """

    def __init__(self, model_config):
        Model.__init__(self, model_config)

    @lazy_property
    def prediction(self):
        """Build the forward graph and return unnormalized class scores (logits).

        Returns:
            Tensor of shape (batch, self.config.class_size_2). No softmax is applied.

        NOTE: tf.layers names variables by creation order (e.g. conv2d, conv2d_1,
        ...). Restoring a saved checkpoint into a freshly built graph relies on
        that, so do not insert or reorder layers without retraining.
        """
        # NOTE(review): with training=True, tf.layers.batch_normalization only
        # updates its moving mean/variance if the ops in tf.GraphKeys.UPDATE_OPS
        # run together with the train op. Confirm Model.train does this;
        # otherwise inference (training=False) uses the initial statistics.

        # Block 1: 11x11 conv (stride 4) -> ReLU -> 3x3 max-pool (stride 2) -> batch norm
        a1 = tf.layers.conv2d(self.X_placeholder, filters=96, kernel_size=(11, 11), strides=(4, 4), padding='SAME')
        h1 = tf.nn.relu(a1)
        mp1 = tf.layers.max_pooling2d(h1, pool_size=(3, 3), strides=(2, 2), padding='SAME')
        bn1 = tf.layers.batch_normalization(mp1, training=self.is_training_placeholder)

        # Block 2: 5x5 conv (stride 1) -> ReLU -> 3x3 max-pool (stride 2) -> batch norm
        a2 = tf.layers.conv2d(bn1, filters=256, kernel_size=(5, 5), strides=(1, 1), padding='SAME')
        h2 = tf.nn.relu(a2)
        mp2 = tf.layers.max_pooling2d(h2, pool_size=(3, 3), strides=(2, 2), padding='SAME')
        bn2 = tf.layers.batch_normalization(mp2, training=self.is_training_placeholder)

        # Block 3: three 3x3 convs (stride 1) with ReLU, then 3x3 max-pool (stride 2)
        a3 = tf.layers.conv2d(bn2, filters=384, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
        h3 = tf.nn.relu(a3)
        a4 = tf.layers.conv2d(h3, filters=384, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
        h4 = tf.nn.relu(a4)
        a5 = tf.layers.conv2d(h4, filters=256, kernel_size=(3, 3), strides=(1, 1), padding='SAME')
        h5 = tf.nn.relu(a5)
        mp3 = tf.layers.max_pooling2d(h5, pool_size=(3, 3), strides=(2, 2), padding='SAME')

        # Flatten using the static per-example feature size instead of a
        # hard-coded 4096. A hard-coded width silently folds several examples
        # into one row whenever the input resolution yields a different size.
        # Requires height/width/channels of the input to be statically known.
        flat_dim = int(np.prod(mp3.get_shape().as_list()[1:]))
        mp_flat = tf.reshape(mp3, [-1, flat_dim])

        # Classifier head: flat_dim -> 4096 -> 4096 -> class_size_2 logits
        aff1 = tf.layers.dense(mp_flat, 4096)
        h6 = tf.nn.relu(aff1)
        aff2 = tf.layers.dense(h6, 4096)
        h7 = tf.nn.relu(aff2)
        y_out = tf.layers.dense(h7, self.config.class_size_2)

        return y_out

## Train the Model

In [None]:
# Build a fresh graph for the NSFW classifier and train it. The checkpoint
# location (saver_address + save_file_name) is the path the loading cell
# restores from.
tf.reset_default_graph()

model_config = ModelConfig(learning_rate=0.003, output='nsfw')
train_config = TrainConfig(
    num_epochs=5,
    minibatch_size=100,
    print_every=100,
    saver_address=r'../../subreddit_classification_parameters/',
    save_file_name='AlexNet_nsfw_classification',
)
model = AlexNet(model_config)

# Session used for training.
session = tf.Session()
model.train(data, session, train_config)

## Return Loss and Accuracy History

In [None]:
# Plot Loss and Accuracy
# plot_loss_acc is not defined in this notebook (presumably inherited from the
# Model base class); it is expected to plot the history collected by model.train.
model.plot_loss_acc(data)

## Test Loading the Model

In [None]:
# Close any earlier session (e.g. the training session) before creating a new
# one, so its device memory is released. Skipped on a fresh kernel where no
# session exists yet.
if 'session' in globals():
    session.close()

# Reset Graph
tf.reset_default_graph()

# Rebuild the same architecture so the saved variables map onto it by name.
# Only saver_address and save_file_name are read from train_config in this cell.
model_config = ModelConfig(learning_rate=0.003, output='nsfw')
train_config = TrainConfig(
    num_epochs=2,
    minibatch_size=100,
    print_every=100,
    saver_address=r'../../subreddit_classification_parameters/',
    save_file_name='AlexNet_nsfw_classification',
)
model = AlexNet(model_config)

# Load Saved Model
session = tf.Session()
saver = tf.train.Saver()
saver.restore(session, train_config.saver_address + train_config.save_file_name)

# Test Model Accuracy
loss_train, acc_train = model.eval(data, session, split='train')
loss_val, acc_val = model.eval(data, session, split='val')

print('Training Accuracy: {:3.1f}%, Validation Accuracy: {:3.1f}%'.format(100 * acc_train, 100 * acc_val))

## Output Predictions for Validation

In [None]:
# Predict validation labels in minibatches so the whole validation set never
# has to pass through the network (and hold all its activations) in one run.
# model.prediction depends only on X_placeholder and is_training_placeholder,
# so labels are not fed. Batch norm runs in inference mode
# (is_training_placeholder=False), so the results do not depend on how the
# set is split into batches.
EVAL_BATCH_SIZE = 100
num_val = data.X_val.shape[0]
val_logits = [
    session.run(model.prediction,
                {model.X_placeholder: data.X_val[start:start + EVAL_BATCH_SIZE],
                 model.is_training_placeholder: False})
    for start in range(0, num_val, EVAL_BATCH_SIZE)
]

# Highest-scoring class index per example (0 = SFW, 1 = NSFW).
y_val_pred = np.argmax(np.concatenate(val_logits, axis=0), axis=1)

## Plot Confusion Matrix for NSFW Classification

In [None]:
# Code to plot the confusion matrix
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion Matrix',
                          cmap=plt.cm.Blues,
                          save_address=''):
    """Plot a confusion matrix as an annotated heatmap.

    Args:
        cm: square confusion matrix with rows = true labels and columns =
            predicted labels (the layout of sklearn.metrics.confusion_matrix).
        classes: tick labels, one per row/column of ``cm``, in label order.
        normalize: if True, divide each row by its sum so each cell shows the
            fraction of that true class. Rows with no samples stay at 0.
        title: figure title.
        cmap: matplotlib colormap for the heatmap.
        save_address: if non-empty, the figure is also saved to
            ``save_address + 'confusion_mat.png'`` before it is shown.
    """
    if normalize:
        # Normalize BEFORE drawing so the colors, colorbar, and text all show
        # the same (row-normalized) values.
        cm = np.asarray(cm, dtype=float)
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    fig = plt.figure(figsize=(6, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # White text on dark (high-value) cells, black text elsewhere.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, round(cm[i, j], 2),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are not clipped.
    plt.tight_layout()
    if save_address:
        # Save before plt.show(): with the inline backend show() typically
        # closes the figure, so a later savefig would write a blank image.
        fig.savefig(save_address + 'confusion_mat.png')
    plt.show()

# Tick labels in label order: 0 = SFW, 1 = NSFW (the encoding in dict_nsfw).
# A dedicated name avoids overwriting the global `classes` (subreddit names)
# that the per-subreddit breakdown cells iterate over.
nsfw_class_names = ['sfw', 'nsfw']

# labels=[0, 1] keeps the matrix 2x2 and aligned with the two tick labels even
# if a class is missing from both the true and the predicted labels.
conf = confusion_matrix(data.y_val_2, y_val_pred, labels=[0, 1])
plot_confusion_matrix(conf, classes=nsfw_class_names, normalize=True)
