In [2]:
# coding:utf-8
import numpy as np
import tensorflow as tf
import os
import time
import datetime
import ctypes

ll = ctypes.cdll.LoadLibrary
lib = ll("./init.so")


class Config(object):

    def __init__(self):
        self.L1_flag = True
        self.hidden_size = 50
        self.nbatches = 100
        self.entity = 0
        self.relation = 0
        self.trainTimes = 3000
        self.margin = 1.0


class TransEModel(object):

    def __init__(self, config):

        entity_total = config.entity
        relation_total = config.relation
        batch_size = config.batch_size
        size = config.hidden_size
        margin = config.margin

        self.pos_h = tf.placeholder(tf.int32, [None])
        self.pos_t = tf.placeholder(tf.int32, [None])
        self.pos_r = tf.placeholder(tf.int32, [None])

        self.neg_h = tf.placeholder(tf.int32, [None])
        self.neg_t = tf.placeholder(tf.int32, [None])
        self.neg_r = tf.placeholder(tf.int32, [None])

        with tf.name_scope("embedding"):
            self.ent_embeddings = tf.get_variable(name="ent_embedding", shape=[
                                                  entity_total, size], initializer=tf.contrib.layers.xavier_initializer(uniform=False))
            self.rel_embeddings = tf.get_variable(name="rel_embedding", shape=[
                                                  relation_total, size], initializer=tf.contrib.layers.xavier_initializer(uniform=False))
            pos_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_h)
            pos_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_t)
            pos_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.pos_r)
            neg_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_h)
            neg_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_t)
            neg_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.neg_r)

        if config.L1_flag:  # use L1-norm
            # TO DO
            pos = tf.reduce_sum(abs(pos_h_e + pos_r_e - pos_t_e), 1, keep_dims = True)
            neg = tf.reduce_sum(abs(neg_h_e + neg_r_e - neg_t_e), 1, keep_dims = True)
            

        else:  # use L2-norm
            # TO DO  
            pos = tf.reduce_sum((pos_h_e + pos_r_e - pos_t_e) ** 2, 1, keep_dims = True)
            neg = tf.reduce_sum((neg_h_e + neg_r_e - neg_t_e) ** 2, 1, keep_dims = True)

        with tf.name_scope("output"):
            # TO DO
            self.loss = tf.reduce_sum(tf.maximum(pos - neg + margin, 0))


def main(_):
    lib.init()
    config = Config()
    config.relation = lib.getRelationTotal()
    config.entity = lib.getEntityTotal()
    config.batch_size = lib.getTripleTotal() // config.nbatches

    with tf.Graph().as_default():
        config_gpu = tf.ConfigProto()
        config_gpu.gpu_options.allow_growth = True
        config_gpu.gpu_options.per_process_gpu_memory_fraction = 0.15
        sess = tf.Session(config=config_gpu)
        with sess.as_default():
            initializer = tf.contrib.layers.xavier_initializer(uniform=False)
            with tf.variable_scope("model", reuse=None, initializer=initializer):
                trainModel = TransEModel(config=config)

            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.GradientDescentOptimizer(0.001)
            grads_and_vars = optimizer.compute_gradients(trainModel.loss)
            train_op = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            saver = tf.train.Saver()
            sess.run(tf.initialize_all_variables())

            def train_step(pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch):
                feed_dict = {
                    trainModel.pos_h: pos_h_batch,
                    trainModel.pos_t: pos_t_batch,
                    trainModel.pos_r: pos_r_batch,
                    trainModel.neg_h: neg_h_batch,
                    trainModel.neg_t: neg_t_batch,
                    trainModel.neg_r: neg_r_batch
                }
                _, step, loss = sess.run(
                    [train_op, global_step, trainModel.loss], feed_dict)
                return loss

            ph = np.zeros(config.batch_size, dtype=np.int32)
            pt = np.zeros(config.batch_size, dtype=np.int32)
            pr = np.zeros(config.batch_size, dtype=np.int32)
            nh = np.zeros(config.batch_size, dtype=np.int32)
            nt = np.zeros(config.batch_size, dtype=np.int32)
            nr = np.zeros(config.batch_size, dtype=np.int32)

            ph_addr = ph.__array_interface__['data'][0]
            pt_addr = pt.__array_interface__['data'][0]
            pr_addr = pr.__array_interface__['data'][0]
            nh_addr = nh.__array_interface__['data'][0]
            nt_addr = nt.__array_interface__['data'][0]
            nr_addr = nr.__array_interface__['data'][0]

            for times in range(config.trainTimes):
                res = 0.0
                for batch in range(config.nbatches):
                    lib.getBatch(ph_addr, pt_addr, pr_addr, nh_addr,
                                 nt_addr, nr_addr, config.batch_size)
                    res += train_step(ph, pt, pr, nh, nt, nr)
                    current_step = tf.train.global_step(sess, global_step)
                print(times)
                print(res)
            #saver.save(sess, 'model.vec')
            # save the embeddings
            f = open("entity2vec.txt", "w")
            enb = trainModel.ent_embeddings.eval()
            for i in enb:
                for j in i:
                    f.write("%f\t" % (j))
                f.write("\n")
            f.close()

            f = open("relation2vec.txt", "w")
            enb = trainModel.rel_embeddings.eval()
            for i in enb:
                for j in i:
                    f.write("%f\t" % (j))
                f.write("\n")
            f.close()

if __name__ == "__main__":
    tf.app.run()


W0707 00:12:42.612470 140714021189440 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0707 00:12:43.341498 140714021189440 deprecation.py:506] From <ipython-input-2-b615b0f691a2>:57: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
W0707 00:12:43.362505 140714021189440 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Us

0
152893.0701904297
1
111340.46826171875
2
94015.94030761719
3
82825.55688476562
4
73962.12188720703
5
66212.72399902344
6
59721.231506347656
7
54452.299865722656
8
49865.960357666016
9
45419.03890991211
10
41956.157287597656
11
38505.515563964844
12
35189.73654174805
13
33122.0871887207
14
30796.784271240234
15
28930.70477294922
16
26639.4884185791
17
25403.173614501953
18
23817.53973388672
19
22395.847412109375
20
21499.057205200195
21
19138.00375366211
22
18681.611099243164
23
17306.246871948242
24
17424.592666625977
25
14751.556167602539
26
14771.735450744629
27
15270.216850280762
28
14187.492080688477
29
13072.93431854248
30
12500.849433898926
31
11296.766502380371
32
10882.22282409668
33
10443.872787475586
34
10735.383033752441
35
9570.193962097168
36
10412.45142364502
37
9178.243312835693
38
9036.945571899414
39
8919.047386169434
40
8994.323341369629
41
8212.886234283447
42
8358.167343139648
43
7868.260467529297
44
7819.955307006836
45
7520.788333892822
46
7390.009742736816
47
7

370
1632.605917930603
371
1621.5435671806335
372
1670.0101833343506
373
1677.5275130271912
374
1659.4813165664673
375
1627.530556678772
376
1615.6079988479614
377
1640.3121347427368
378
1663.2757682800293
379
1639.472755432129
380
1589.5733127593994
381
1558.4776091575623
382
1645.1065740585327
383
1582.6795644760132
384
1695.0811405181885
385
1587.2214803695679
386
1578.6346521377563
387
1614.0993947982788
388
1562.9219036102295
389
1558.5800666809082
390
1661.6220426559448
391
1644.8283672332764
392
1569.7889289855957
393
1623.1664724349976
394
1580.4526205062866
395
1584.9667654037476
396
1608.279794216156
397
1622.033546447754
398
1557.0947523117065
399
1576.8093843460083
400
1556.6848764419556
401
1533.7006697654724
402
1489.4026446342468
403
1554.7786855697632
404
1558.120849609375
405
1604.8143558502197
406
1567.6746792793274
407
1536.5282649993896
408
1563.806287765503
409
1515.736762046814
410
1540.3081731796265
411
1505.6451706886292
412
1512.0568523406982
413
1554.3648319244

731
980.5770273208618
732
956.0090582370758
733
972.3925023078918
734
967.0402133464813
735
978.2797064781189
736
970.0863952636719
737
958.1922812461853
738
976.8210740089417
739
916.433525800705
740
961.9383249282837
741
976.7963027954102
742
860.0417108535767
743
948.7682547569275
744
1001.2606382369995
745
966.9861068725586
746
960.0689463615417
747
983.9119629859924
748
981.639244556427
749
949.4208960533142
750
892.3194599151611
751
939.2110409736633
752
944.7451889514923
753
934.5276114940643
754
951.7165303230286
755
934.1708488464355
756
924.9560461044312
757
883.6998107433319
758
954.7983827590942
759
927.1498036384583
760
905.9345812797546
761
915.1142747402191
762
931.9847621917725
763
896.3249464035034
764
923.367772102356
765
909.8788111209869
766
881.1243019104004
767
924.3037967681885
768
916.1804609298706
769
916.6111936569214
770
907.0940706729889
771
922.1515779495239
772
944.4846515655518
773
915.7317314147949
774
897.1306748390198
775
867.5441036224365
776
910.5957

1101
673.6139445304871
1102
695.2532558441162
1103
672.6941251754761
1104
674.0227437019348
1105
690.9267828464508
1106
685.9986457824707
1107
705.6975905895233
1108
718.1899843215942
1109
662.1521005630493
1110
679.5902454853058
1111
665.6506869792938
1112
678.3736608028412
1113
659.8337466716766
1114
677.9172875881195
1115
647.3449950218201
1116
646.5081286430359
1117
663.1453931331635
1118
673.4524574279785
1119
659.8627562522888
1120
679.3329336643219
1121
718.9818849563599
1122
694.3398232460022
1123
671.032041311264
1124
671.4160323143005
1125
693.0081820487976
1126
667.9916014671326
1127
639.8889663815498
1128
661.8436558246613
1129
619.5335714817047
1130
691.2738556861877
1131
665.2493255138397
1132
650.857328414917
1133
694.6482274532318
1134
637.6739203929901
1135
659.3684055805206
1136
633.31871342659
1137
647.2163617610931
1138
659.9177906513214
1139
705.664888381958
1140
680.2851016521454
1141
649.1682744026184
1142
633.653578042984
1143
645.7479228973389
1144
651.33474421

1460
533.5140504837036
1461
543.8271518945694
1462
565.042445898056
1463
543.5869369506836
1464
506.0333483219147
1465
519.1219651699066
1466
561.8854262828827
1467
542.3181056976318
1468
531.2425235509872
1469
572.9316563606262
1470
518.1915526390076
1471
541.3240308761597
1472
527.0024154186249
1473
543.8216356039047
1474
552.3142130374908
1475
562.0584416389465
1476
538.9150593280792
1477
494.87068271636963
1478
558.152928352356
1479
540.656417965889
1480
497.19911909103394
1481
546.2646887302399
1482
541.047719836235
1483
516.7907810211182
1484
508.47220826148987
1485
567.5291240215302
1486
522.4784021377563
1487
542.5737357139587
1488
495.4472162723541
1489
498.97417306900024
1490
519.3006806373596
1491
519.5610032081604
1492
498.83872175216675
1493
537.7751824855804
1494
510.1755071878433
1495
510.10125207901
1496
534.3186085224152
1497
521.0798032283783
1498
514.5109832286835
1499
518.0313320159912
1500
476.37922501564026
1501
514.8043930530548
1502
536.5162165164948
1503
515.95

1813
464.3287695646286
1814
454.47440481185913
1815
428.70163345336914
1816
465.1143054962158
1817
452.84267687797546
1818
483.8273096084595
1819
478.18126034736633
1820
467.3574755191803
1821
446.3634605407715
1822
454.208261013031
1823
462.47463726997375
1824
444.37791752815247
1825
412.4912779331207
1826
462.0040748119354
1827
455.90887784957886
1828
440.894935131073
1829
463.5795089006424
1830
453.13642501831055
1831
499.9017086029053
1832
463.1851843595505
1833
435.2945816516876
1834
428.2716438770294
1835
415.60129976272583
1836
412.5886251926422
1837
420.9861843585968
1838
422.39156198501587
1839
442.8159086704254
1840
405.88973796367645
1841
420.6200613975525
1842
455.2730596065521
1843
455.187415599823
1844
426.60275316238403
1845
433.1636574268341
1846
459.1309766769409
1847
445.7981802225113
1848
446.27233052253723
1849
455.0011501312256
1850
446.7259306907654
1851
453.56259322166443
1852
430.1485915184021
1853
387.12153339385986
1854
444.50541138648987
1855
423.598267316818

2164
408.69953429698944
2165
383.2431137561798
2166
370.2318434715271
2167
386.45560598373413
2168
361.19766914844513
2169
393.33927071094513
2170
395.7063887119293
2171
398.3424357175827
2172
377.9420018196106
2173
373.52258574962616
2174
388.16914343833923
2175
358.49260568618774
2176
378.99052357673645
2177
373.28250336647034
2178
380.7937023639679
2179
370.32744312286377
2180
379.11138927936554
2181
388.71123629808426
2182
392.89462983608246
2183
364.6522536277771
2184
380.49633610248566
2185
400.84711813926697
2186
372.61164236068726
2187
368.7284119129181
2188
404.6823101043701
2189
372.80911135673523
2190
367.75850439071655
2191
379.3249444961548
2192
374.1017087697983
2193
365.62806129455566
2194
384.82531213760376
2195
365.89878821372986
2196
368.63748490810394
2197
352.4464771747589
2198
377.55101346969604
2199
353.05150747299194
2200
384.5572829246521
2201
412.48651576042175
2202
396.29030561447144
2203
376.36353600025177
2204
378.9025995731354
2205
373.7001863718033
2206
36

2514
320.5369015932083
2515
358.2177826166153
2516
328.29500019550323
2517
333.26999068260193
2518
314.50879776477814
2519
334.8066226243973
2520
321.1461843252182
2521
332.4798846244812
2522
334.81088972091675
2523
347.53647458553314
2524
318.985897064209
2525
318.1775152683258
2526
319.9069825410843
2527
329.91169595718384
2528
346.09006679058075
2529
313.18607687950134
2530
329.548233628273
2531
315.706139087677
2532
330.3185498714447
2533
349.0854215621948
2534
321.5238060951233
2535
335.54045581817627
2536
346.3821142911911
2537
330.69234442710876
2538
320.03781819343567
2539
327.77358746528625
2540
351.5548577308655
2541
332.3819351196289
2542
344.9369230866432
2543
360.16564083099365
2544
321.23806285858154
2545
337.07116544246674
2546
334.2294842004776
2547
325.46984016895294
2548
331.9368727207184
2549
325.30585765838623
2550
346.0506032705307
2551
308.8446161746979
2552
338.4485960006714
2553
360.79739141464233
2554
336.73764395713806
2555
336.91013526916504
2556
305.77853107

2865
298.47389221191406
2866
323.0226812362671
2867
311.2149838209152
2868
291.1055543422699
2869
289.13749146461487
2870
301.1128350496292
2871
311.4691321849823
2872
269.2311897277832
2873
305.1235430240631
2874
289.8701343536377
2875
307.8921103477478
2876
287.97775506973267
2877
297.9541575908661
2878
306.086083650589
2879
304.1364767551422
2880
324.1509132385254
2881
342.05188155174255
2882
281.9933491945267
2883
284.55906426906586
2884
303.7413960695267
2885
288.3144347667694
2886
293.999648809433
2887
291.81574606895447
2888
295.5288841724396
2889
304.85386753082275
2890
298.5183045864105
2891
294.84963941574097
2892
279.9303029179573
2893
327.99976801872253
2894
304.22736644744873
2895
300.0465611219406
2896
290.1069197654724
2897
301.2542531490326
2898
285.4576136469841
2899
315.0706160068512
2900
274.2670373916626
2901
286.5543122291565
2902
299.03628039360046
2903
307.5654112100601
2904
296.050705909729
2905
292.0908308029175
2906
293.2462123632431
2907
283.74222803115845
29

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [3]:
%tb

SystemExit: 