#Import necessary Libraries

In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf

  from ._conv import register_converters as _register_converters


Import the 100 element glove vectors. Also tried with word2vec 300 vectors - didn't seem to do any better.

In [2]:
import gensim
W2V_PATH = 'word2vec/glove.6B.100d.w2vformat.txt'
w2v = gensim.models.KeyedVectors.load_word2vec_format(W2V_PATH, binary=False)

Using TensorFlow backend.


In [3]:
tf.test.is_gpu_available()

True

Extract training data and split into training and test sets

In [None]:
temp_data = pd.read_csv('data/train.csv')
split_num = int(len(temp_data)*0.8)
test_data = temp_data.iloc[split_num:]
train_data = temp_data.iloc[:split_num]
print(len(train_data))
print(len(test_data))

train_data

In [7]:
labels_train = train_data.as_matrix(columns=['toxic','severe_toxic','obscene','threat','insult','identity_hate'])
labels_test = test_data.as_matrix(columns=['toxic','severe_toxic','obscene','threat','insult','identity_hate'])

Parse punctuation - convert it into usable tokens.

In [8]:
import collections
import re

def clean_punc(input_string):
    proc_string = input_string.replace('<',' <less ')
    proc_string = proc_string.replace('>',' <greater> ')
    proc_string = re.sub("https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{2,256}\.[a-z]{2,6}\b([-a-zA-Z0-9@:%_\+.~#?&//=]*)",' <url> ',proc_string)
    proc_string = proc_string.replace(' <less ',' <less> ')
    proc_string = proc_string.replace('?',' <question> ')
    proc_string = proc_string.replace('...',' <suspension> ')
    proc_string = proc_string.replace('. ',' <period> ')
    proc_string = proc_string if not proc_string.endswith('.') else proc_string[:-1]
    proc_string = proc_string.replace('/',' <slash> ')
    proc_string = proc_string.replace('\\',' <backslash> ')
    proc_string = proc_string.replace('; ',' <semicolon> ')
    proc_string = proc_string.replace(': ',' <colon> ')
    proc_string = proc_string.replace(', ',' <comma> ')
    proc_string = proc_string.replace('!',' <exclame> ')
    proc_string = proc_string.replace('\n',' <newline> ')
    proc_string = proc_string.replace(' - ',' <dash> ')
    proc_string = proc_string.replace('""',' <quote> ')
    proc_string = proc_string.replace('"',' <quote> ')
    proc_string = proc_string.replace('(',' <openbracket> ')
    proc_string = proc_string.replace(')',' <closebracket> ')
    return proc_string

def clean_word(input_word):
    out_word = input_word.lower()
    if ( out_word.startswith("'") and out_word.endswith("'")):
        out_word = out_word[1:-1]
    
    if len(out_word)>0:
        out_word = out_word if not out_word[-1] in ['.',':',';',','] else out_word[:-1]
    
    return out_word
    

Count the number of occurences of each word - acoss both the training and test sets

In [9]:
comments = [clean_punc(comment) for comment in train_data.comment_text]
comment_words = []
for comment in comments:
    comment_words.append ([clean_word(word) for word in comment.split()])
flatten = lambda l: [item for sublist in l for item in sublist]

flat_comments = flatten(comment_words)

word_counts = collections.Counter()
for word in flat_comments:
    word_counts[word]+=1

In [10]:
test_comments = [clean_punc(comment) for comment in test_data.comment_text]
test_comment_words = []
for comment in test_comments:
    test_comment_words.append ([clean_word(word) for word in comment.split()])
    
flat_comments = flatten(test_comment_words)

for word in flat_comments:
    word_counts[word]+=1

Output the total number of distinct words - and build a list of the 100 most common (and so, least useful).

In [11]:
print("Total words: {}".format(len(word_counts)))

very_common = [word for word,_ in word_counts.most_common(100)]

Total words: 277174


# Time to start on the embedding
First configure the hyper-parameters

In [12]:
# Embedding Hyper-paramters
comment_length = 200
embed_size = 100
n_labels = 6

Build a filtered word list to use for training - excluding the most common words and anything appearing less than 6 times.

In [13]:
filtered_words = set([word for num,word in enumerate(word_counts.keys()) if word_counts[word]>5 and word not in very_common])
len_embedding = len(filtered_words)
len_embedding

40434

Define mapping dictionaries for words and a seed embedding which uses values from glove where possible and random values otherwise.

In [49]:
word_to_int = {word:num for num,word in enumerate(filtered_words)}
int_to_word = {num:word for num,word in enumerate(filtered_words)}

embeddings = np.zeros([len_embedding,embed_size])

n_mapped = 0
for word in filtered_words:
    if word in w2v.vocab:
        embeddings[word_to_int[word],:] = w2v[word]
        n_mapped+=1
    else:
        embeddings[word_to_int[word],:] = np.random.uniform(size=[1,embed_size])

print("Mapped {} of {} words.".format(n_mapped,len(filtered_words)))

Mapped 33572 of 40434 words.


In [15]:
def map_word(in_word):
    work_word = clean_word(in_word)
    return word_to_int[work_word]
    

In [16]:
def process_comment(input_comment):
    result_matrix = np.zeros((comment_length))
    temp_matrix = [word_to_int[word] for word in input_comment if word in filtered_words]
    if (len(temp_matrix) == 0):
        return result_matrix

    temp_matrix = temp_matrix[-comment_length:]
    
    result_matrix[-len(temp_matrix):] = temp_matrix
    return result_matrix

In [17]:
import time

start = time.perf_counter()
#Pre-build integer arrays
print("Training comments:")
comment_ints =[]
for i in range(0,len(comment_words)):
    comment_ints.append(process_comment(comment_words[i]))
    if (i%100==0 and i >0):
        elapsed = time.perf_counter() - start
        print("\rProcessed {}/{} in {}.  ETA {}.".format(i,len(comment_words),elapsed,(len(comment_words)-i)*elapsed/i),end='')

start = time.perf_counter()
print("\nTest comments:")
test_comment_ints =[]
for i in range(0,len(test_comment_words)):
    test_comment_ints.append(process_comment(test_comment_words[i]))
    if (i%100==0 and i >0):
        elapsed = time.perf_counter() - start
        print("\rProcessed {}/{} in {}.  ETA {}.".format(i,len(comment_words),elapsed,(len(comment_words)-i)*elapsed/i),end='')

np.array(comment_ints).shape

Training comments:


Processed 100/127656 in 0.001646717999392422.  ETA 2.1004876113049975.Processed 200/127656 in 0.002949911999166943.  ETA 1.8799199188291096.Processed 300/127656 in 0.0040112409988068976.  ETA 1.702852028813504.Processed 400/127656 in 0.005028456998843467.  ETA 1.5997533096120606.Processed 500/127656 in 0.006174559999635676.  ETA 1.570264702627348.Processed 600/127656 in 0.007523699998273514.  ETA 1.5932187116343994.Processed 700/127656 in 0.008814891996735241.  ETA 1.5987191833393133.Processed 800/127656 in 0.010480343997187447.  ETA 1.6618681476340134.Processed 900/127656 in 0.011715882999851601.  ETA 1.6500649616990994.Processed 1000/127656 in 0.012942059998749755.  ETA 1.639189551201649.Processed 1100/127656 in 0.014004127999214688.  ETA 1.6111876573351038.Processed 1200/127656 in 0.015225839997583535.  ETA 1.604499018945353.Processed 1300/127656 in 0.01636544299981324.  ETA 1.5906707043726167.Processed 1400/127656 in 0.017576615999132628.  ETA 1.585109449704635.Proce

Processed 34300/127656 in 0.4051993719986058.  ETA 1.102851095402386.Processed 34400/127656 in 0.4067051599995466.  ETA 1.102549313980166.Processed 34500/127656 in 0.4076689269968483.  ETA 1.100777001835316.Processed 34600/127656 in 0.40897064799719374.  ETA 1.0999182838158053.Processed 34700/127656 in 0.4102798529966094.  ETA 1.0990770609554128.Processed 34800/127656 in 0.41138581199993496.  ETA 1.0976908321570678.Processed 34900/127656 in 0.41270635899854824.  ETA 1.0968765339618722.Processed 35000/127656 in 0.4140616249969753.  ETA 1.096151255020564.Processed 35100/127656 in 0.4154866459975892.  ETA 1.0956063249844121.Processed 35200/127656 in 0.4166573709990189.  ETA 1.0943884628717413.Processed 35300/127656 in 0.41785006299687666.  ETA 1.0932283404572107.Processed 35400/127656 in 0.419145566997031.  ETA 1.0923359725671777.Processed 35500/127656 in 0.420325059996685.  ETA 1.0911401754663241.Processed 35600/127656 in 0.421360245996766.  ETA 1.0895713147606263.Processe

Processed 51300/127656 in 0.6070889319998969.  ETA 0.9036039472082675.Processed 51400/127656 in 0.6086221890000161.  ETA 0.9029395650658605.Processed 51500/127656 in 0.6098809439972683.  ETA 0.9018658868166207.Processed 51600/127656 in 0.6112809239966737.  ETA 0.9009996503002135.Processed 51700/127656 in 0.6124522859972785.  ETA 0.899795470700373.Processed 51800/127656 in 0.6138471709964506.  ETA 0.8989187452337212.Processed 51900/127656 in 0.6150060569998459.  ETA 0.8976955463213936.Processed 52000/127656 in 0.6161128599997028.  ETA 0.8963968180026445.Processed 52100/127656 in 0.617466281997622.  ETA 0.8954564760578182.Processed 52200/127656 in 0.6186918229977891.  ETA 0.8943296972437007.Processed 52300/127656 in 0.6198797079996439.  ETA 0.8931482844363511.Processed 52400/127656 in 0.6211322699964512.  ETA 0.8920597349399415.Processed 52500/127656 in 0.6222449729975779.  ETA 0.8907703464877326.Processed 52600/127656 in 0.62349964799796.  ETA 0.8896842125500929.Processed

Processed 67900/127656 in 0.8079305399987788.  ETA 0.7110264705179238.Processed 68000/127656 in 0.8094261759979418.  ETA 0.7101048228725473.Processed 68100/127656 in 0.8107080909976503.  ETA 0.7089945824883416.Processed 68200/127656 in 0.8119708609992813.  ETA 0.7078671482635377.Processed 68300/127656 in 0.8131063349974283.  ETA 0.7066286913632117.Processed 68400/127656 in 0.8142990439991991.  ETA 0.7054401191698324.Processed 68500/127656 in 0.8155389179992198.  ETA 0.7042922661775453.Processed 68600/127656 in 0.8166628740000306.  ETA 0.7030443540371109.Processed 68700/127656 in 0.8178460549970623.  ETA 0.7018476276332868.Processed 68800/127656 in 0.8188193989990395.  ETA 0.7004714323762714.Processed 68900/127656 in 0.8200694489969464.  ETA 0.6993323736613147.Processed 69000/127656 in 0.8214405829967291.  ETA 0.6982959251631324.Processed 69100/127656 in 0.8227341349993367.  ETA 0.6971927642405378.Processed 69200/127656 in 0.8239166919993295.  ETA 0.6959952911490289.Proce

Processed 84800/127656 in 1.00882264199754.  ETA 0.5098361219981907.Processed 84900/127656 in 1.0101818129987805.  ETA 0.5087318444826368.Processed 85000/127656 in 1.0113229969974782.  ETA 0.5075175736461698.Processed 85100/127656 in 1.0124497749966395.  ETA 0.5062962705611868.Processed 85200/127656 in 1.0136812269993243.  ETA 0.5051273494540295.Processed 85300/127656 in 1.0146903669992753.  ETA 0.5038478919650797.Processed 85400/127656 in 1.0158025879973138.  ETA 0.5026200721125819.Processed 85500/127656 in 1.0168873659968085.  ETA 0.5013789918241106.Processed 85600/127656 in 1.01812265999979.  ETA 0.5002122265064388.Processed 85700/127656 in 1.0193834029996651.  ETA 0.4990577602830099.Processed 85800/127656 in 1.0204338179974002.  ETA 0.4978004415629275.Processed 85900/127656 in 1.0220051319993217.  ETA 0.4967968136410207.Processed 86000/127656 in 1.0233366519969422.  ETA 0.4956757159951701.Processed 86100/127656 in 1.0244299679980031.  ETA 0.49443916086091777.Processe

Processed 101800/127656 in 1.2101552099993569.  ETA 0.30736515824895255.Processed 101900/127656 in 1.2116748229964287.  ETA 0.30626002689986276.Processed 102000/127656 in 1.2130257409990008.  ETA 0.30511165108892513.Processed 102100/127656 in 1.2144120359989756.  ETA 0.30397173351606094.Processed 102200/127656 in 1.2154846929988707.  ETA 0.30275321276887723.Processed 102300/127656 in 1.216528892997303.  ETA 0.3015279238596248.Processed 102400/127656 in 1.2176470879967383.  ETA 0.30032123881294553.Processed 102500/127656 in 1.2188874609964842.  ETA 0.29914471189100056.Processed 102600/127656 in 1.220253909999883.  ETA 0.2979988495999714.Processed 102700/127656 in 1.221781733998796.  ETA 0.29689177170081743.Processed 102800/127656 in 1.222926019996521.  ETA 0.29569113962094873.Processed 102900/127656 in 1.2240247539994016.  ETA 0.29447965801758197.Processed 103000/127656 in 1.2252340529994399.  ETA 0.29329486224033197.Processed 103100/127656 in 1.2266586269979598.  ETA 0.29

Processed 118900/127656 in 1.4116885669973271.  ETA 0.10395916814658197.Processed 119000/127656 in 1.4133750679975492.  ETA 0.10280818982005703.Processed 119100/127656 in 1.414516808999906.  ETA 0.10161717731152978.Processed 119200/127656 in 1.4159670489971177.  ETA 0.10044813226778211.Processed 119300/127656 in 1.4171618199980003.  ETA 0.09926072227915583.Processed 119400/127656 in 1.4182569179974962.  ETA 0.09806640799821884.Processed 119500/127656 in 1.419191104996571.  ETA 0.09686127742554003.Processed 119600/127656 in 1.4203695729993342.  ETA 0.09567305418129295.Processed 119700/127656 in 1.4215263619989855.  ETA 0.09448340631632354.Processed 119800/127656 in 1.4231427869999607.  ETA 0.09332395437956337.Processed 119900/127656 in 1.4241156049974961.  ETA 0.0921221070255261.Processed 120000/127656 in 1.4253858699994453.  ETA 0.0909396185059646.Processed 120100/127656 in 1.4265810159995453.  ETA 0.08975225775930529.Processed 120200/127656 in 1.4276403459989524.  ETA 0.

Processed 100/127656 in 0.0011928750027436763.  ETA 1.5215836384997237.Processed 200/127656 in 0.002568114003224764.  ETA 1.6366076919750776.Processed 300/127656 in 0.0037596190013573505.  ETA 1.5960334584562224.Processed 400/127656 in 0.004892832002951764.  ETA 1.556605573419074.Processed 500/127656 in 0.005905448000703473.  ETA 1.5018262919549015.Processed 600/127656 in 0.007417988002998754.  ETA 1.5708331395150161.Processed 700/127656 in 0.008569351000915049.  ETA 1.5541864652459585.Processed 800/127656 in 0.009571388000040315.  ETA 1.5177349951663928.Processed 900/127656 in 0.010659529001713963.  ETA 1.5012880646013946.Processed 1000/127656 in 0.012011687002086546.  ETA 1.5213522289362735.Processed 1100/127656 in 0.013341041001694975.  ETA 1.5348988954640994.Processed 1200/127656 in 0.014620019002904883.  ETA 1.5406576025261165.Processed 1300/127656 in 0.015623823001078563.  ETA 1.5185875224032945.Processed 1400/127656 in 0.016696377002517693.  ETA 1.5057269820213384.

Processed 24900/127656 in 0.29952295500334003.  ETA 1.236055452382458.Processed 25000/127656 in 0.3006840660018497.  ETA 1.2346809391794353.Processed 25100/127656 in 0.30191694400127744.  ETA 1.233601358924104.Processed 25200/127656 in 0.3030818020015431.  ETA 1.2322440121377025.Processed 25300/127656 in 0.3043353430002753.  ETA 1.2312469710725762.Processed 25400/127656 in 0.3056896420021076.  ETA 1.2306535445892721.Processed 25500/127656 in 0.30693460200200207.  ETA 1.229616125573197.Processed 25600/127656 in 0.3081403540018073.  ETA 1.22842078000033.Processed 25700/127656 in 0.30951806099983514.  ETA 1.2279075263540542.Processed 25800/127656 in 0.31066342200210784.  ETA 1.2264702911413448.Processed 25900/127656 in 0.3118511190004938.  ETA 1.2252016395758396.Processed 26000/127656 in 0.3130897300034121.  ETA 1.2241326766625715.Processed 26100/127656 in 0.3142445140001655.  ETA 1.2227362399923682.Processed 26200/127656 in 0.3154422389998217.  ETA 1.2215079312964088.Proce

(127656, 200)

## That took too long - better save the results...
And provide for reloading them

In [40]:
# Hyper-paramters
layer_size = 1024
layer_count = 3
hidden_fc_layers = [100]
keep_prob_training = 0.5
learning_rate = 0.00001
epochs = 100
batch_size=64

checkpoint_path = 'a5cp1'

# Approach using improved embeddings

In [51]:
from random import shuffle

# A function to get the lists of inputs with each label
def get_label_lists(labels):
    labels_true = []
    offset = []
    for i in range(0,n_labels):
        labels_true.append([])
        offset.append(0)

    for i,label in enumerate(labels):
        for ii in range(0,n_labels):
            if label[ii] == 1:
                labels_true[ii].append(i)
    
    labels_true.append([i for i,label in enumerate(labels) if sum(label)==0])
    offset.append(0)
    
    for i in range(0,n_labels+1):
        shuffle(labels_true[i])
    
    return labels_true,offset
    
def get_batches(input_ints,labels,batch_size):
    
    num_inputs = len(input_ints)
    num_batches = num_inputs//batch_size
    
    labels_true,offset = get_label_lists(labels)
    group_size = batch_size // 20
    list_length = [len(labels_list) for labels_list in labels_true]
    
    for ii in range(0,num_batches):
        indicies = set()
        for i in range(0,n_labels+1):
            indicies.update([labels_true[i][ii % list_length[i]] for ii in range(offset[i],offset[i]+group_size) ])
            offset[i]+=group_size
            if offset[i]>=list_length[i]:
                offset[i]=0
                shuffle(labels_true[i])
             
        while len(indicies) < batch_size:
            indicies.add(labels_true[n_labels][offset[n_labels] % list_length[n_labels]])
            offset[n_labels]+=1
            if offset[n_labels] == list_length[n_labels]:
                offset[n_labels]=0
                shuffle(labels_true[n_labels])
            
        features = np.array([input_ints[i] for i in indicies])
        return_labels = np.array([labels[i] for i in indicies])
        yield features, return_labels

def get_test_batches(input_ints,labels,batch_size):
    num_inputs = len(input_ints)
    num_batches = num_inputs//batch_size
    if (num_inputs > num_batches * batch_size):
        num_batches += 1
        
    for ii in range(0,num_batches):
        end = ii * batch_size + batch_size if ii * batch_size + batch_size <= num_inputs else num_inputs - 1
        indicies = [0] * batch_size
        indicies[:end-ii*batch_size] = range(ii * batch_size,end)
        
        features = np.array([input_ints[i] for i in indicies])
        return_labels = np.array([labels[i] for i in indicies])
        yield features, return_labels
    

# Alright - enough with the pre-processing, let's build a network
Firstly, define placeholders and the embedding (only variable that is being explicitly defined).

Embedding is initialised by copying it from the list built earlier - only way I can find to do partial transfer learning and partial random.

In [42]:
# Building a graph and placeholders
graph = tf.Graph()

with graph.as_default():
    inputs_ = tf.placeholder(tf.int32,[None,comment_length],name='inputs')
    labels_ = tf.placeholder(tf.float32,[None,None],name='outputs')
    keep_prob_ = tf.placeholder(tf.float32,name='keep_prob')
    initial_embedding_ = tf.placeholder(tf.float32,[len_embedding,embed_size],name='embed')
    embedding_var = tf.Variable(tf.constant(0.0, shape=[len_embedding,embed_size]),
                trainable=True, name='embed_var')
    
    embedding_init = embedding_var.assign(initial_embedding_)

## Build the LSTM network
Just based on hyper-paratmeters defined earlier. Firstly define the LSTM network, then apply the embedding, finally get the categorical outputs.

In [43]:
# Build the LSTM network
with graph.as_default():
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=layer_size)
    drop = tf.contrib.rnn.DropoutWrapper(cell=lstm_cell,input_keep_prob=keep_prob_)
    network = drop
    for _ in range(layer_count):
        network = tf.contrib.rnn.MultiRNNCell([network])

    initial_state = network.zero_state(batch_size,tf.float32)

In [44]:
# Forward pass
with graph.as_default():
    embed = tf.nn.embedding_lookup(embedding_var, inputs_)
    outputs, final_state = tf.nn.dynamic_rnn(network,embed,initial_state=initial_state)

In [45]:
# Get outputs
with graph.as_default():
    predictions = tf.contrib.layers.flatten(outputs)
    for size in hidden_fc_layers:
        predictions = tf.contrib.layers.fully_connected(predictions, size, activation_fn=tf.tanh)
        #predictions = tf.nn.leaky_relu(predictions,alpha=0.2)
        predictions = tf.nn.dropout(predictions,keep_prob_)
    predictions = tf.contrib.layers.fully_connected(predictions, n_labels, activation_fn=tf.sigmoid)
    cost = tf.losses.sigmoid_cross_entropy(labels_, predictions)
    
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)

Define TF functions to return accuracy stats on the test set - including a confusion matrix for each output

In [46]:
# Determine accuracy on test set
with graph.as_default():
    validation_metrics_var_scope = "validation_metrics"
    binary_pred = tf.cast(tf.round(predictions), tf.bool)
    binary_labels = tf.cast(labels_, tf.bool)
    accuracy = tf.reduce_sum(tf.cast(tf.equal(binary_pred,binary_labels),tf.int32))
    correct_pos = tf.reduce_sum(tf.cast(tf.logical_and(binary_pred,binary_labels),tf.int32),axis=0)
    false_pos = tf.reduce_sum(tf.cast(tf.logical_and(binary_pred,tf.logical_not(binary_labels)),tf.int32),axis=0)
    false_neg = tf.reduce_sum(tf.cast(tf.logical_and(tf.logical_not(binary_pred),binary_labels),tf.int32),axis=0)
    correct_neg = tf.reduce_sum(tf.cast(tf.logical_and(tf.logical_not(binary_pred),tf.logical_not(binary_labels)),tf.int32),axis=0)
    auc = tf.metrics.auc(labels=labels_,predictions=predictions,name=validation_metrics_var_scope)

# Train the model

In [None]:
#Training
with graph.as_default():
    saver = tf.train.Saver()
    
n_batches = len(comment_words)//batch_size

val_acc = []
false_pos_list = []

# Find a checkpoint from previous training - if there is one
last_checkpoint = tf.train.latest_checkpoint(checkpoint_path)

print("Starting...")

iteration = 0
with tf.Session(graph=graph) as sess:
    # Initialise variables, load the embeddings, reset the validation metrics
    sess.run(tf.global_variables_initializer())
    sess.run(embedding_init,feed_dict={initial_embedding_:embeddings})
    
    validation_metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope=validation_metrics_var_scope)
    validation_metrics_init_op = tf.variables_initializer(var_list=validation_metrics_vars, name='validation_metrics_init')
    sess.run(validation_metrics_init_op)
    
    # Load a checkpoint if we found one
    if last_checkpoint != None:
        saver.restore(sess,last_checkpoint)
        print("Restored checkpoint from {}.".format(last_checkpoint))    
    
    for e in range(epochs):
        state = sess.run(initial_state)
            
        # Generate randomised, resampled batches
        for ii,(x,y) in enumerate(get_batches(comment_ints,labels_train,batch_size),1):
            feed = {inputs_:x,
                    labels_:y,
                    keep_prob_:keep_prob_training,
                    initial_state:state}
            
            # Do the training
            loss, state, _ = sess.run([cost,final_state,optimizer],feed_dict=feed)
            iteration += 1
            
            # Regular updates on progress
            if iteration%100==0:
                print("\rEpoch: {}/{}".format(e, epochs),
                      "Iteration: {}/{}".format(iteration, n_batches*epochs),
                      "Train loss: {:.3f}".format(loss),end='')

            # Run the test set - show confusion matricies (in columns)
            if iteration%2000==0:
                val_acc.clear()
                false_pos_list.clear()
                total_correct_pos = 0
                total_false_pos = 0
                total_correct_neg = 0
                total_false_neg = 0
                
                val_state = sess.run(initial_state)
                sess.run(validation_metrics_init_op)
                if iteration%2000==0:
                    test_subset_x,test_subset_y = test_comment_ints,labels_test
                else:
                    test_subset_x,test_subset_y = test_comment_ints[:4096],labels_test[:4096]
                
                for x, y in get_test_batches(test_subset_x, test_subset_y, batch_size):
                    feed = {inputs_: x,
                            labels_: y,
                            keep_prob_: 1,
                            initial_state: val_state}
                    
                    auc_val, n_correct_pos, n_correct_neg, n_false_pos, n_false_neg, val_state, batch_acc = sess.run([auc, correct_pos, correct_neg, false_pos, false_neg, final_state,accuracy], feed_dict=feed)
                    #print(predictions)
                    val_acc.append(batch_acc/len(test_subset_y))
                    auc_value = auc_val[1]
                    total_correct_pos += n_correct_pos
                    total_false_pos += n_false_pos
                    total_correct_neg += n_correct_neg
                    total_false_neg += n_false_neg
                print("During epoch {}".format(e))
                print("  Val acc      : {}".format(np.mean(val_acc)))
                print("  AuC          : {}".format(auc_value))
                print("  Correct pos  : {}".format('  '.join(['{:5}'.format(x) for x in total_correct_pos])))
                print("  False neg    : {}".format('  '.join(['{:5}'.format(x) for x in total_false_neg])))
                print("  Correct neg  : {}".format('  '.join(['{:5}'.format(x) for x in total_correct_neg])))
                print("  False pos    : {}\n".format('  '.join(['{:5}'.format(x) for x in total_false_pos])))
                
                # Save a checkpoint (indescrimnantly)
                saver.save(sess, "{}/epoch{}iter{}.ckpt".format(checkpoint_path,e,iteration))

Starting...
INFO:tensorflow:Restoring parameters from a5cp1/epoch36iter72000.ckpt
Restored checkpoint from a5cp1/epoch36iter72000.ckpt.
Epoch: 1/100 Iteration: 2000/199400 Train loss: 0.648During epoch 1
  Val acc      : 0.011783491783818302
  AuC          : 0.8670397400856018
  Correct pos  :  1844     93   1238     21   1019     83
  False neg    :  1193    218    431     71    563    222
  Correct neg  : 28596  31526  29960  31808  29918  31553
  False pos    :   303     99    307     36    436     78

Epoch: 2/100 Iteration: 4000/199400 Train loss: 0.647During epoch 2
  Val acc      : 0.01177972426130657
  AuC          : 0.859678328037262
  Correct pos  :  1735    125   1223     28   1008     94
  False neg    :  1302    186    446     64    574    211
  Correct neg  : 28672  31479  29977  31805  29920  31533
  False pos    :   227    146    290     39    434     98

Epoch: 2/100 Iteration: 4100/199400 Train loss: 0.651

# Produce CSV for Kaggle
Finally, produce a CSV in the right format to upload to Kaggle.

In [50]:
submit_data = pd.read_csv('data/test.csv')

submit_comments = [clean_punc(comment) for comment in submit_data.comment_text]
submit_comment_ints = []
#submit_comments = submit_comments[:204] #For testing on a subset
for comment in submit_comments:
    words = [word for word in comment.split()]
    submit_comment_ints.append (process_comment(words))
    
label_placeholder = np.zeros([len(submit_comments),n_labels])
results = []

with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    
    last_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
    saver.restore(sess,last_checkpoint)

    for x, y in get_test_batches(submit_comment_ints, label_placeholder, batch_size):
        feed = {inputs_: x,
            keep_prob_: 1,
            initial_state: val_state}
        #print (x)
        pred, val_state = sess.run([predictions, final_state], feed_dict=feed)

        for the_pred in pred:
            results.append(the_pred)
        
        print("\rDone: {}/{}".format(len(results), len(label_placeholder)),end='')

results = results[:len(submit_comment_ints)]

submission = pd.concat([submit_data['id'],pd.DataFrame(results,columns=['toxic','severe_toxic','obscene','threat','insult','identity_hate'])],axis=1)

submission.to_csv('submission.csv',index=False, float_format='%.4f')

INFO:tensorflow:Restoring parameters from a5cp1/epoch36iter72000.ckpt
Done: 153216/153164

# General test commands

In [None]:
sum(labels_test[:1024])

In [None]:
word_counts['explanation']

In [None]:
[ len([word for word in comment if word in filtered_words])/len(comment) for comment in comment_words[0:10]]

In [None]:
'norman' in word_counts.keys()