# In [1]:
# coding:utf-8
import numpy as np
import tensorflow as tf
import os
import time
import datetime
import ctypes

# The data pipeline (triple loading and batch sampling) lives in a compiled
# shared library that is resolved relative to the current working directory.
_INIT_LIB_PATH = "./init.so"
ll = ctypes.cdll.LoadLibrary
lib = ll(_INIT_LIB_PATH)


class Config(object):
    """Hyper-parameters and dataset sizes for TransE training.

    ``entity``, ``relation`` and ``batch_size`` start at 0 and are expected
    to be overwritten with values derived from the loaded dataset before a
    ``TransEModel`` is built from this config.
    """

    def __init__(self):
        self.L1_flag = True     # True: L1 distance in the score; False: L2
        self.hidden_size = 50   # embedding dimension for entities/relations
        self.nbatches = 100     # mini-batches per training epoch
        self.entity = 0         # number of entities (set from the data)
        self.relation = 0       # number of relations (set from the data)
        # Triples per mini-batch. Declared here so every attribute the model
        # and training loop read exists from construction on; set from data.
        self.batch_size = 0
        self.trainTimes = 3000  # number of training epochs
        self.margin = 1.0       # hinge margin in the ranking loss


class TransEModel(object):
    """TransE knowledge-graph embedding model, built in TF1 graph mode.

    A triple (h, r, t) is scored by the distance ``||e_h + e_r - e_t||``
    (L1 or L2 norm depending on ``config.L1_flag``). The model minimises the
    margin ranking (hinge) loss

        sum(max(0, margin + d(positive) - d(negative)))

    so positive triples are pushed to have a smaller distance than the
    corrupted (negative) triples fed alongside them.

    Args:
        config: object providing ``entity`` and ``relation`` (number of
            entities / relations), ``hidden_size`` (embedding dimension),
            ``margin`` (hinge margin) and ``L1_flag``.

    Attributes:
        pos_h, pos_t, pos_r, neg_h, neg_t, neg_r: int32 placeholders of shape
            ``[None]`` holding head / tail / relation ids of the positive and
            negative batches (fed with equal lengths).
        ent_embeddings: ``[entity, hidden_size]`` entity embedding variable.
        rel_embeddings: ``[relation, hidden_size]`` relation embedding variable.
        loss: scalar hinge loss summed over the batch.
    """

    def __init__(self, config):
        entity_total = config.entity
        relation_total = config.relation
        size = config.hidden_size
        margin = config.margin
        # Vector norm used for the translation distance.
        norm_ord = 1 if config.L1_flag else 2

        self.pos_h = tf.placeholder(tf.int32, [None])
        self.pos_t = tf.placeholder(tf.int32, [None])
        self.pos_r = tf.placeholder(tf.int32, [None])

        self.neg_h = tf.placeholder(tf.int32, [None])
        self.neg_t = tf.placeholder(tf.int32, [None])
        self.neg_r = tf.placeholder(tf.int32, [None])

        with tf.name_scope("embedding"):
            # Xavier/Glorot init with a normal (not uniform) distribution,
            # shared by both embedding tables.
            initializer = tf.contrib.layers.xavier_initializer(uniform=False)
            self.ent_embeddings = tf.get_variable(
                name="ent_embedding", shape=[entity_total, size],
                initializer=initializer)
            self.rel_embeddings = tf.get_variable(
                name="rel_embedding", shape=[relation_total, size],
                initializer=initializer)
            pos_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_h)
            pos_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_t)
            pos_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.pos_r)
            neg_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_h)
            neg_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_t)
            neg_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.neg_r)

        # Translation residuals h + r - t; close to zero for plausible triples.
        positive = pos_h_e + pos_r_e - pos_t_e
        negative = neg_h_e + neg_r_e - neg_t_e
        # One distance per triple, reduced over the embedding axis.
        pos_e = tf.norm(positive, ord=norm_ord, axis=1)
        neg_e = tf.norm(negative, ord=norm_ord, axis=1)

        with tf.name_scope("output"):
            self.loss = tf.reduce_sum(
                tf.math.maximum(0.0, margin + pos_e - neg_e))


def _write_embeddings(path, matrix):
    """Write a 2-D embedding matrix to ``path`` as tab-separated text.

    One row per line; every value is formatted with ``%f`` and followed by a
    tab, so each line ends in a tab before the newline.
    """
    with open(path, "w") as out:
        for row in matrix:
            out.write("".join("%f\t" % value for value in row) + "\n")


def main(_):
    """Train TransE on data served by ``init.so`` and dump the embeddings.

    Reads dataset sizes from the C library, then trains for
    ``config.trainTimes`` epochs of ``config.nbatches`` mini-batches each,
    printing the epoch index and the summed epoch loss. Finally it writes
    ``entity2vec.txt`` and ``relation2vec.txt`` to the working directory.

    Raises:
        ValueError: if the triple count is smaller than ``config.nbatches``,
            which would make every mini-batch empty.
    """
    lib.init()
    config = Config()
    config.relation = lib.getRelationTotal()
    config.entity = lib.getEntityTotal()
    config.batch_size = lib.getTripleTotal() // config.nbatches
    if config.batch_size <= 0:
        raise ValueError(
            "triple total is smaller than nbatches (%d); every batch would "
            "be empty" % config.nbatches)

    with tf.Graph().as_default():
        config_gpu = tf.ConfigProto()
        config_gpu.gpu_options.allow_growth = True
        config_gpu.gpu_options.per_process_gpu_memory_fraction = 0.15
        # Context-managed session: it is the default session inside the block
        # and is closed (releasing device memory) on exit, even on error.
        with tf.Session(config=config_gpu) as sess:
            initializer = tf.contrib.layers.xavier_initializer(uniform=False)
            with tf.variable_scope("model", reuse=None, initializer=initializer):
                trainModel = TransEModel(config=config)

            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.GradientDescentOptimizer(0.001)
            grads_and_vars = optimizer.compute_gradients(trainModel.loss)
            train_op = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            saver = tf.train.Saver()
            # initialize_all_variables() is deprecated; this is its replacement.
            sess.run(tf.global_variables_initializer())

            def train_step(pos_h_batch, pos_t_batch, pos_r_batch,
                           neg_h_batch, neg_t_batch, neg_r_batch):
                """Apply one gradient step on the batch and return its loss."""
                feed_dict = {
                    trainModel.pos_h: pos_h_batch,
                    trainModel.pos_t: pos_t_batch,
                    trainModel.pos_r: pos_r_batch,
                    trainModel.neg_h: neg_h_batch,
                    trainModel.neg_t: neg_t_batch,
                    trainModel.neg_r: neg_r_batch,
                }
                _, loss = sess.run([train_op, trainModel.loss], feed_dict)
                return loss

            # Host buffers that getBatch fills in place with one batch of ids.
            # NOTE(review): assumes the C side writes 32-bit ints (matches
            # dtype=np.int32) - confirm against init.so's getBatch signature.
            ph, pt, pr, nh, nt, nr = (
                np.zeros(config.batch_size, dtype=np.int32) for _ in range(6))
            # Pass the buffer addresses as c_void_p. getBatch has no argtypes,
            # so ctypes would convert a bare Python int to a C int and
            # silently truncate a 64-bit pointer.
            batch_ptrs = [
                ctypes.c_void_p(buf.__array_interface__['data'][0])
                for buf in (ph, pt, pr, nh, nt, nr)
            ]

            for times in range(config.trainTimes):
                res = 0.0
                for _batch in range(config.nbatches):
                    lib.getBatch(*batch_ptrs, config.batch_size)
                    res += train_step(ph, pt, pr, nh, nt, nr)
                print(times)
                print(res)
            # saver.save(sess, 'model.vec')

            # Save the learned embeddings.
            _write_embeddings(
                "entity2vec.txt", sess.run(trainModel.ent_embeddings))
            _write_embeddings(
                "relation2vec.txt", sess.run(trainModel.rel_embeddings))

if __name__ == "__main__":
    # Pass main explicitly instead of relying on tf.app.run() looking up
    # sys.modules['__main__'].main. tf.app.run parses flags, calls main, and
    # then exits the process via sys.exit.
    tf.app.run(main=main)


W0707 01:25:00.688051 139630753994560 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0707 01:25:01.412189 139630753994560 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0707 01:25:01.556650 139630753994560 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/tf_should_use.py:198: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and wi

0
153136.091796875
1
111394.43627929688
2
93960.99505615234
3
82431.79760742188
4
73670.90869140625
5
66250.49389648438
6
59765.644104003906
7
54176.269287109375
8
49686.131896972656
9
45061.63787841797
10
41255.246978759766
11
38083.96517944336
12
35420.943786621094
13
33063.66958618164
14
30739.577407836914
15
28596.232162475586
16
26875.080017089844
17
25425.42027282715
18
23463.787567138672
19
22169.85464477539
20
20402.131301879883
21
20425.737731933594
22
17864.285675048828
23
17734.976928710938
24
16703.207122802734
25
15792.953720092773
26
15128.553039550781
27
15249.072441101074
28
14054.22737121582
29
13806.29702758789
30
12466.371360778809
31
11608.475173950195
32
10835.784759521484
33
11137.458786010742
34
11168.357498168945
35
10129.207778930664
36
10328.999198913574
37
10453.627151489258
38
9075.835006713867
39
8919.89705657959
40
9108.845069885254
41
8276.48684310913
42
8604.597175598145
43
7868.951354980469
44
7877.340782165527
45
7763.176574707031
46
7490.486095428467


371
1608.9666738510132
372
1665.7608575820923
373
1714.7933883666992
374
1655.1037693023682
375
1640.9911451339722
376
1631.6547946929932
377
1625.1273593902588
378
1650.3245391845703
379
1623.9319257736206
380
1608.6372699737549
381
1566.2220792770386
382
1645.09672164917
383
1615.6641054153442
384
1709.6736249923706
385
1606.1744627952576
386
1591.997121810913
387
1638.5818529129028
388
1554.7525339126587
389
1541.8754625320435
390
1678.8680934906006
391
1635.3175210952759
392
1585.7892866134644
393
1642.3057641983032
394
1563.3064308166504
395
1591.526177406311
396
1580.5128316879272
397
1615.390085220337
398
1556.8050327301025
399
1603.9121885299683
400
1575.878128528595
401
1518.8098015785217
402
1523.478747844696
403
1565.5398206710815
404
1529.7842855453491
405
1611.537838459015
406
1559.4663982391357
407
1557.5410351753235
408
1562.528685092926
409
1501.6682090759277
410
1512.2467217445374
411
1493.960970878601
412
1508.9407720565796
413
1563.8317008018494
414
1537.981499195098

733
985.9581110477448
734
967.1708836555481
735
961.5832147598267
736
995.1449036598206
737
973.9891195297241
738
982.85462641716
739
926.5550246238708
740
978.593184709549
741
995.743757724762
742
892.0733859539032
743
962.6152300834656
744
969.0200848579407
745
954.166271686554
746
983.6941840648651
747
1001.9994492530823
748
962.7356696128845
749
939.2438931465149
750
916.8369734287262
751
937.1690351963043
752
955.9821846485138
753
951.3771932125092
754
955.9498224258423
755
940.2235445976257
756
939.855217218399
757
896.3851611614227
758
966.7961888313293
759
928.7610504627228
760
910.6636748313904
761
951.2074794769287
762
929.3035752773285
763
922.0692417621613
764
935.4968852996826
765
918.0063443183899
766
898.2458384037018
767
941.7445042133331
768
930.3362865447998
769
922.5179531574249
770
924.1181445121765
771
932.782717704773
772
956.7604746818542
773
939.118973493576
774
908.4971241950989
775
869.433798789978
776
915.1312127113342
777
945.1537976264954
778
922.1073186397

1103
684.3826644420624
1104
672.6463468074799
1105
697.9555473327637
1106
680.3435666561127
1107
734.7719392776489
1108
700.0692181587219
1109
670.8973083496094
1110
675.444171667099
1111
653.7331869602203
1112
702.593065738678
1113
671.3680062294006
1114
711.6766905784607
1115
656.1021254062653
1116
667.9701290130615
1117
710.2419912815094
1118
665.4951145648956
1119
661.4276320934296
1120
669.514529466629
1121
709.2143032550812
1122
679.7025806903839
1123
663.1335668563843
1124
696.1192960739136
1125
699.1324253082275
1126
685.0906610488892
1127
656.2499353885651
1128
667.8684763908386
1129
623.068089723587
1130
690.0934784412384
1131
678.010984659195
1132
655.4458539485931
1133
694.4422545433044
1134
614.5788261890411
1135
664.3386099338531
1136
646.1393489837646
1137
657.354131937027
1138
661.7748320102692
1139
681.7597608566284
1140
691.4719405174255
1141
671.9209957122803
1142
640.9144175052643
1143
675.1851358413696
1144
669.7251477241516
1145
668.2996814250946
1146
639.91326022

1462
556.7673029899597
1463
544.8536469936371
1464
519.1002336740494
1465
522.891676902771
1466
553.8289213180542
1467
539.7108025550842
1468
543.6498649120331
1469
567.9811680316925
1470
543.4921782016754
1471
537.3418021202087
1472
528.9924796819687
1473
544.7471685409546
1474
570.5912871360779
1475
557.7403330802917
1476
534.5754148960114
1477
511.46537375450134
1478
556.4693541526794
1479
541.5403370857239
1480
506.3707287311554
1481
543.6777672767639
1482
548.6301312446594
1483
532.1360671520233
1484
520.0037155151367
1485
564.8628430366516
1486
522.8381190299988
1487
549.6103947162628
1488
548.617201089859
1489
508.70778489112854
1490
544.9699528217316
1491
522.9338614940643
1492
508.532062292099
1493
533.2106807231903
1494
517.7381888628006
1495
500.36269974708557
1496
541.515456199646
1497
542.5780363082886
1498
520.909500837326
1499
526.4422173500061
1500
481.20767521858215
1501
522.5265266895294
1502
540.7851114273071
1503
507.79767775535583
1504
507.0364110469818
1505
526.27

1815
434.4432816505432
1816
464.52587699890137
1817
459.8557620048523
1818
477.64455342292786
1819
476.5507171154022
1820
458.91295433044434
1821
406.0067901611328
1822
475.82313776016235
1823
476.6428551673889
1824
454.49946546554565
1825
426.5089545249939
1826
457.19789814949036
1827
471.1912434101105
1828
461.13189697265625
1829
461.6570737361908
1830
435.24587631225586
1831
498.77597975730896
1832
459.50211906433105
1833
443.2243422269821
1834
419.83968675136566
1835
430.209020614624
1836
438.7437970638275
1837
436.52932620048523
1838
438.85408425331116
1839
442.0794451236725
1840
419.0983030796051
1841
434.744428396225
1842
441.0959720611572
1843
463.0943212509155
1844
418.2532238960266
1845
443.7123705148697
1846
451.60686230659485
1847
449.15200686454773
1848
451.1569411754608
1849
467.77079343795776
1850
443.7176761627197
1851
459.26059210300446
1852
422.9808804988861
1853
402.0205545425415
1854
439.5363861322403
1855
429.6710567474365
1856
434.2785736322403
1857
442.6056172847

2167
388.6172049045563
2168
367.28669595718384
2169
400.60706329345703
2170
400.1674702167511
2171
396.28644704818726
2172
388.823237657547
2173
366.9140865802765
2174
375.84032702445984
2175
367.72355341911316
2176
361.7238528728485
2177
382.6979024410248
2178
398.8362033367157
2179
371.80009865760803
2180
379.07008719444275
2181
396.7484481334686
2182
405.9336440563202
2183
397.8253698348999
2184
377.4709255695343
2185
399.64312767982483
2186
367.9239729642868
2187
374.67821061611176
2188
382.3827588558197
2189
369.4265847206116
2190
362.94954538345337
2191
365.8919219970703
2192
373.9454473257065
2193
361.2283179759979
2194
382.01716661453247
2195
379.6703749895096
2196
391.9294865131378
2197
364.94468200206757
2198
375.71385407447815
2199
364.5686638355255
2200
393.27286207675934
2201
393.5620183944702
2202
400.9240291118622
2203
379.6693090200424
2204
409.42747020721436
2205
362.8849573135376
2206
378.60497093200684
2207
406.52184987068176
2208
366.9896889925003
2209
363.544419288

2518
329.5192952156067
2519
349.0692837238312
2520
320.19253039360046
2521
341.308956861496
2522
337.6035370826721
2523
351.7607789039612
2524
338.04754650592804
2525
334.4274626970291
2526
308.6233617067337
2527
321.2094922065735
2528
343.6952528953552
2529
322.9552993774414
2530
353.55317306518555
2531
329.55971121788025
2532
346.8410789966583
2533
351.7731878757477
2534
320.831756234169
2535
322.8629596233368
2536
346.30174493789673
2537
327.5897636413574
2538
316.8784724473953
2539
344.4427601099014
2540
346.3708031177521
2541
323.92150926589966
2542
337.4832841157913
2543
358.4901216030121
2544
336.31727623939514
2545
337.93859565258026
2546
322.8258287906647
2547
337.1532961130142
2548
335.8251123428345
2549
340.06938648223877
2550
357.1551525592804
2551
313.96308970451355
2552
338.6814241409302
2553
343.6737072467804
2554
327.4745452404022
2555
334.4969482421875
2556
315.7833796739578
2557
339.49459397792816
2558
357.02348136901855
2559
312.34874653816223
2560
349.6209968328476


2869
310.5095372200012
2870
310.0384347438812
2871
316.50430178642273
2872
269.6454792022705
2873
310.370866894722
2874
283.99405884742737
2875
308.32096683979034
2876
298.1288220882416
2877
303.40760803222656
2878
305.1193903684616
2879
311.11692583560944
2880
327.8550908565521
2881
346.90716004371643
2882
300.74991941452026
2883
290.2442650794983
2884
314.47114872932434
2885
298.3680098056793
2886
315.09990882873535
2887
306.2955356836319
2888
279.53307950496674
2889
304.9697823524475
2890
313.0602536201477
2891
291.1197930574417
2892
290.8409478664398
2893
322.44552659988403
2894
315.09170365333557
2895
299.4721932411194
2896
291.2019441127777
2897
304.08532679080963
2898
311.3219188451767
2899
301.6554720401764
2900
293.5991904735565
2901
295.8939301967621
2902
312.1742088794708
2903
316.22761511802673
2904
294.1239045858383
2905
302.51115441322327
2906
301.71163725852966
2907
295.34121322631836
2908
322.65603733062744
2909
287.5502247810364
2910
289.87136721611023
2911
322.0252156

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
