In [1]:
%%bash
# Clear TensorBoard logs from earlier runs (the SummaryWriter set up later in the
# notebook writes to /tmp/tensorflow). -f: succeed even when the directory does
# not exist yet, e.g. on a fresh machine or after a reboot.
rm -rf /tmp/tensorflow/

In [2]:
import tensorflow as tf
import numpy as np
import pandas as pd
import time

def L2_dist(a, b, axis=None):
    """Euclidean (L2) distance between tensors `a` and `b`.

    Args:
        a, b: tensors of broadcast-compatible shape.
        axis: dimension(s) to reduce over. The default ``None`` reduces over
            every element and returns a single scalar (for a batch of
            vectors, one number for the whole batch). Pass ``axis=1`` on
            (batch, dim) inputs to get one distance per row instead.

    Returns:
        ``sqrt(sum((a - b) ** 2))`` reduced over ``axis``.
    """
    # `axis` is passed positionally: the second parameter of tf.reduce_sum is
    # named `reduction_indices` in TF 0.x and `axis` in later releases.
    return tf.sqrt(tf.reduce_sum(tf.square(a - b), axis))

def margin_cost(pos, neg, margin=1.):
    """Summed hinge loss over ``margin + pos - neg``.

    Args:
        pos: distance(s) of the true triple(s).
        neg: distance(s) of the corrupted triple(s).
        margin: gap that ``neg`` must exceed ``pos`` by before a pair
            stops contributing.

    Returns:
        Scalar tensor: the sum of ``margin + pos - neg`` over the entries
        where that quantity is positive.
    """
    violation = margin + pos - neg
    # 1.0 where the margin is violated, 0.0 elsewhere. Multiplying by this
    # mask zeroes both the value and the gradient of satisfied entries.
    active = tf.to_float(tf.greater(violation, 0))
    return tf.reduce_sum(active * violation)

# NOTE(review): each line is parsed as exactly three whitespace-separated tokens
# (subject, predicate, object). Canonical N-Triples lines end with a " ." token,
# and literals may contain spaces; either would misalign the columns. Confirm
# walmart.nt is in this simplified 3-token form.
triples = pd.read_csv('walmart.nt',
                      sep=r'\s+',  # documented replacement for delim_whitespace=True
                      names=['s', 'p', 'o'])

# Vocabularies: every subject/object is an entity, every predicate a relation.
entities       = set(triples.s.values) | set(triples.o.values)
predicates     = set(triples.p.values)
# id -> name. Ids follow set iteration order; for string names that order depends
# on hash randomisation, so the same name can get a different id in a new kernel.
entity_dict    = dict(enumerate(entities))
predicate_dict = dict(enumerate(predicates))

# name -> id (inverse of the dicts above).
entity_ids    = {name: idx for idx, name in entity_dict.items()}
predicate_ids = {name: idx for idx, name in predicate_dict.items()}

dim = 8                 # embedding dimension
bound = 6/np.sqrt(dim)  # half-width of the uniform range used to initialise embeddings
batch_size = 1024       # triples per training step (the placeholders fix this size)

# (n_triples, 3) integer matrix with one [head id, predicate id, tail id] row per
# triple. Series.map(dict) does the lookup vectorised instead of a Python call per row.
X = np.array([triples.s.map(entity_ids).values,
              triples.p.map(predicate_ids).values,
              triples.o.map(entity_ids).values]).T

In [3]:
with tf.name_scope('initialize_embeddings'):
    # Uniform(-bound, +bound) initialisation for every entity and predicate vector.
    entity_embeddings    = tf.Variable(tf.random_uniform([len(entities), dim], -bound, +bound),
                                       name='entity_embeddings')

    predicate_embeddings = tf.Variable(tf.random_uniform([len(predicates), dim], -bound, +bound),
                                       name='predicate_embeddings')

with tf.name_scope('read_inputs'):
    # One id per triple. The leading dimension is fixed to batch_size, so every
    # feed must contain exactly batch_size rows.
    pos_head = tf.placeholder(tf.int32, [batch_size], name='positive_head')
    pos_tail = tf.placeholder(tf.int32, [batch_size], name='positive_tail')
    neg_head = tf.placeholder(tf.int32, [batch_size], name='corrupted_head')
    neg_tail = tf.placeholder(tf.int32, [batch_size], name='corrupted_tail')
    link     = tf.placeholder(tf.int32, [batch_size], name='link')

with tf.name_scope('lookup_embeddings'):
    # [batch_size] ids -> [batch_size, dim] vectors.
    pos_head_vec = tf.nn.embedding_lookup(entity_embeddings, pos_head)
    pos_tail_vec = tf.nn.embedding_lookup(entity_embeddings, pos_tail)
    neg_head_vec = tf.nn.embedding_lookup(entity_embeddings, neg_head)
    neg_tail_vec = tf.nn.embedding_lookup(entity_embeddings, neg_tail)
    link_vec     = tf.nn.embedding_lookup(predicate_embeddings, link)

with tf.name_scope('normalize_embeddings'):
    # Entity vectors are scaled to unit L2 norm per row; predicate vectors are not.
    pos_head_vec = tf.nn.l2_normalize(pos_head_vec, 1)
    pos_tail_vec = tf.nn.l2_normalize(pos_tail_vec, 1)
    neg_head_vec = tf.nn.l2_normalize(neg_head_vec, 1)
    neg_tail_vec = tf.nn.l2_normalize(neg_tail_vec, 1)

with tf.name_scope('train'):
    # Per-triple distance ||head + link - tail||_2, reduced over the embedding
    # axis (1) only. Summing over the batch axis as well would collapse each
    # batch into one scalar. The hinge in margin_cost would then switch the
    # whole batch on or off at once instead of scoring each (true, corrupted) pair.
    pos_dist = tf.sqrt(tf.reduce_sum(tf.square(pos_head_vec + link_vec - pos_tail_vec), 1))
    neg_dist = tf.sqrt(tf.reduce_sum(tf.square(neg_head_vec + link_vec - neg_tail_vec), 1))
    diff = neg_dist - pos_dist
    # Hinge summed over the batch, with margin sqrt(dim).
    loss = margin_cost(pos_dist, neg_dist, np.sqrt(dim))
    train = tf.train.GradientDescentOptimizer(0.001).minimize(loss)

# TensorBoard summaries. The distance tensors hold one value per triple, so their
# histograms show the distribution across the batch. The loss is reported per triple.
pos_hist  = tf.histogram_summary('distance_true', pos_dist)
neg_hist  = tf.histogram_summary('distance_corrupt', neg_dist)
diff_hist = tf.histogram_summary('distance_diff', diff)
loss_hist = tf.histogram_summary('loss', loss / batch_size)

sess = tf.Session()
merged = tf.merge_all_summaries()
# Public alias of the summary writer (instead of reaching into tf.python.*).
# The log directory is the one cleared by the %%bash cell at the top.
writer = tf.train.SummaryWriter("/tmp/tensorflow", sess.graph_def)
init = tf.initialize_all_variables()
sess.run(init)

In [4]:
# Epoch counter carried across re-runs of the training cell. It lives in its own
# cell so that re-running the training cell continues the epoch numbering (and
# the TensorBoard step derived from it) instead of restarting at 0.
last_epoch = 0

In [5]:
def corrupt(mat, n_entities=None):
    """Build corrupted (negative) triples by replacing either the head or the tail.

    For each row, a fair coin decides which end is replaced. With probability 1/2
    the tail gets a uniformly random entity id and the head is kept; otherwise the
    head is replaced and the tail kept. The predicate column is never touched. The
    replacement is not filtered, so it can equal the original id.

    Args:
        mat: integer array of shape (n, 3) with columns [head, predicate, tail]
            ids, as in ``X``. It is not modified.
        n_entities: size of the entity vocabulary; replacement ids are drawn from
            ``[0, n_entities)``. Defaults to ``len(entities)``.

    Returns:
        A new array with the same shape and dtype as ``mat``.
    """
    if n_entities is None:
        n_entities = len(entities)
    n = len(mat)
    out = mat.copy()
    # randint's upper bound is exclusive. This draws the same values the
    # deprecated random_integers(0, n_entities - 1, n) did.
    new_entities = np.random.randint(0, n_entities, n)
    # rand() >= 0.5 marks rows whose tail is replaced; the others get a new head.
    # This is equivalent to the old (rand() + 1/2).astype(np.int) mask, and np.int
    # no longer exists in NumPy >= 1.24.
    replace_tail = np.random.rand(n) >= 0.5
    out[:, 0] = np.where(replace_tail, mat[:, 0], new_entities)
    out[:, 2] = np.where(replace_tail, new_entities, mat[:, 2])
    return out

In [13]:
# The placeholders have a fixed leading dimension of batch_size, so each epoch
# uses only full batches. The len(X) % batch_size leftover triples are skipped,
# and a different subset is skipped after each reshuffle.
n_batches = len(X) // batch_size
for epoch in range(last_epoch, last_epoch + 50):
    epoch_start = time.time()
    # Shuffle row indices instead of reassigning X, so running this cell does
    # not rewrite X and does not build a second full copy of it.
    order = np.random.permutation(len(X))
    for i in range(n_batches):
        sample = X[order[i * batch_size:(i + 1) * batch_size]]
        corrupt_sample = corrupt(sample)
        feed = {pos_head: sample[:, 0],
                link:     sample[:, 1],
                pos_tail: sample[:, 2],
                neg_head: corrupt_sample[:, 0],
                neg_tail: corrupt_sample[:, 2]}

        sess.run(train, feed_dict=feed)

        if epoch > 0 and i % 100 == 0:
            result = sess.run(merged, feed_dict=feed)
            # Global step = batches seen so far. It must be unique and increasing
            # across epochs; the old i*n_batches + epoch*n_batches collided.
            writer.add_summary(result, epoch * n_batches + i)

        elapsed = time.time() - epoch_start
        # i + 1 batches are done, so n_batches - i - 1 remain.
        remaining = (n_batches - i - 1) * (elapsed / (i + 1))
        print('Batch: {:d}/{:d}, ETA: {:.0f} s'.format(i, n_batches, remaining), end='\r')

    # Recomputed here so it is defined even when n_batches == 0.
    elapsed = time.time() - epoch_start
    print('Epoch {:d} took: {:.0f} s'.format(epoch, elapsed))
    # Point past the finished epoch so a re-run of this cell starts a new epoch
    # number instead of repeating this one.
    last_epoch = epoch + 1

Epoch 30 took: 12 s
Epoch 31 took: 12 s
Epoch 32 took: 12 s
Epoch 33 took: 12 s
Epoch 34 took: 12 s
Epoch 35 took: 12 s
Epoch 36 took: 12 s
Epoch 37 took: 12 s
Epoch 38 took: 12 s
Epoch 39 took: 12 s
Epoch 40 took: 12 s
Epoch 41 took: 12 s
Epoch 42 took: 12 s
Epoch 43 took: 12 s
Epoch 44 took: 12 s
Epoch 45 took: 12 s
Epoch 46 took: 12 s
Epoch 47 took: 12 s
Epoch 48 took: 12 s
Epoch 49 took: 12 s
Epoch 50 took: 12 s
Epoch 51 took: 12 s
Epoch 52 took: 12 s
Epoch 53 took: 12 s
Epoch 54 took: 12 s
Epoch 55 took: 12 s
Epoch 56 took: 12 s
Epoch 57 took: 12 s
Epoch 58 took: 12 s
Epoch 59 took: 12 s
Epoch 60 took: 12 s
Epoch 61 took: 12 s
Epoch 62 took: 12 s
Epoch 63 took: 12 s
Epoch 64 took: 12 s
Epoch 65 took: 12 s
Epoch 66 took: 13 s
Epoch 67 took: 12 s
Epoch 68 took: 12 s
Epoch 69 took: 12 s
Epoch 70 took: 12 s
Epoch 71 took: 13 s
Epoch 72 took: 13 s
Epoch 73 took: 12 s
Epoch 74 took: 12 s
Epoch 75 took: 12 s
Epoch 76 took: 12 s
Epoch 77 took: 13 s
Epoch 78 took: 12 s
Epoch 79 took: 12 s


In [10]:
# Unit-normalised head embeddings for the last batch fed by the training loop,
# shape [batch_size, dim].
sess.run(pos_head_vec, feed_dict=feed)

array([[ 0.61396319, -0.34390607,  0.36752665, ...,  0.24422398,
         0.23334531, -0.41245151],
       [ 0.24910051, -0.38901827, -0.54924905, ..., -0.29962975,
        -0.25688305, -0.23794715],
       [-0.5913465 ,  0.27450365,  0.12143519, ...,  0.04485816,
        -0.38927367,  0.58995283],
       ..., 
       [ 0.01336615,  0.53853577,  0.24134755, ..., -0.20717558,
         0.18626279,  0.48572305],
       [ 0.19225186,  0.18832542,  0.47431207, ...,  0.2303036 ,
        -0.52475494,  0.26388609],
       [ 0.16205913, -0.02751728,  0.06905802, ..., -0.45021504,
         0.40441337,  0.23357819]], dtype=float32)

In [11]:
# Last batch fed to the training loop: rows of [head id, predicate id, tail id].
sample

array([[317751,      2,  38494],
       [290065,      1, 290587],
       [157825,      1, 121427],
       ..., 
       [128788,      0, 221019],
       [289628,      0, 142268],
       [165740,      1, 152428]])

In [12]:
# (number of triples, 3): one [head id, predicate id, tail id] row per triple.
X.shape

(3956563, 3)