In [2]:
# coding:utf-8
import numpy as np
import tensorflow as tf
import os
import time
import datetime
import ctypes

# Load the native helper library from the current working directory.
# `lib` is used by main() for init(), getRelationTotal(), getEntityTotal(),
# getTripleTotal() and getBatch().
# NOTE(review): the relative path means the script must be started from the
# directory that contains init.so -- confirm this is intended.
ll = ctypes.cdll.LoadLibrary
lib = ll("./init.so")


class Config(object):
    """Hyper-parameters and dataset sizes shared by the model and training loop.

    All values are plain instance attributes so callers can overwrite them
    after construction; ``main`` fills in ``entity``, ``relation`` and
    ``batch_size`` from the loaded data.
    """

    def __init__(self):
        # True selects the L1 norm for the triple score, False the L2 norm.
        self.L1_flag = True
        # Number of columns in every embedding table.
        self.hidden_size = 50
        # Number of mini-batches processed per training pass.
        self.nbatches = 100
        # Row counts of the entity / relation tables; 0 until set by the caller.
        self.entity = 0
        self.relation = 0
        # Number of training passes (outer loop iterations).
        self.trainTimes = 3000
        # Margin added inside the hinge loss.
        self.margin = 1.0
        # Triples per mini-batch. Declared here so that every Config instance
        # carries the attribute; previously it only existed once a caller
        # assigned it, and reading it earlier raised AttributeError.
        self.batch_size = 0


class TransRModel(object):
    """Margin-based translation embedding model built as a TF1 graph.

    NOTE(review): despite the class name, the score built here is
    ``||h + r - t||`` on the raw entity/relation embeddings.
    ``transfer_matrix`` is created (and therefore trained-over/saved as a
    variable) but never enters the loss -- confirm whether the projection
    step was meant to be applied.

    Attributes:
        pos_h, pos_t, pos_r: int32 placeholders for the head, tail and
            relation ids of the positive triples.
        neg_h, neg_t, neg_r: int32 placeholders for the corrupted triples.
        ent_embeddings: variable of shape [config.entity, config.hidden_size].
        rel_embeddings: variable of shape [config.relation, config.hidden_size].
        transfer_matrix: variable of shape [config.relation, config.hidden_size].
        loss: scalar tensor, sum over the batch of
            ``max(0, margin + pos_score - neg_score)``.
    """

    def __init__(self, config):
        """Create placeholders, embedding variables and the hinge loss.

        Args:
            config: object exposing ``entity``, ``relation``, ``hidden_size``,
                ``margin`` and ``L1_flag``.
        """
        entity_total = config.entity
        relation_total = config.relation
        size = config.hidden_size
        margin = config.margin

        self.pos_h = tf.placeholder(tf.int32, [None])
        self.pos_t = tf.placeholder(tf.int32, [None])
        self.pos_r = tf.placeholder(tf.int32, [None])

        self.neg_h = tf.placeholder(tf.int32, [None])
        self.neg_t = tf.placeholder(tf.int32, [None])
        self.neg_r = tf.placeholder(tf.int32, [None])

        with tf.name_scope("embedding"):
            xavier = tf.contrib.layers.xavier_initializer(uniform=False)
            self.ent_embeddings = tf.get_variable(
                name="ent_embedding",
                shape=[entity_total, size],
                initializer=xavier)
            self.rel_embeddings = tf.get_variable(
                name="rel_embedding",
                shape=[relation_total, size],
                initializer=xavier)
            # Kept so the set of graph variables stays the same as before,
            # even though nothing below reads it.
            self.transfer_matrix = tf.get_variable(
                name="transfer_matrix",
                shape=[relation_total, size],
                initializer=xavier)
            pos_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_h)
            pos_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_t)
            pos_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.pos_r)
            neg_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_h)
            neg_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_t)
            neg_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.neg_r)

        # Translation residual per triple; a smaller norm means a better fit.
        positive = pos_h_e + pos_r_e - pos_t_e
        negative = neg_h_e + neg_r_e - neg_t_e

        # One code path for both norms: only the order differs.
        norm_order = 1 if config.L1_flag else 2
        pos_e = tf.norm(positive, ord=norm_order, axis=1)
        neg_e = tf.norm(negative, ord=norm_order, axis=1)

        with tf.name_scope("output"):
            # Hinge loss: penalise positives that are not at least `margin`
            # closer than their corrupted counterparts.
            self.loss = tf.reduce_sum(
                tf.math.maximum(0.0, (margin + pos_e - neg_e)))


def _write_embeddings(path, matrix):
    """Write a 2-D array to `path`: one row per line, each value as "%f\\t"."""
    with open(path, "w") as out:
        for row in matrix:
            out.write("".join("%f\t" % (value) for value in row))
            out.write("\n")


def main(_):
    """Train the embedding model and dump entity/relation vectors to text.

    Reads dataset sizes from the native library, builds the graph, runs
    ``config.trainTimes`` passes of ``config.nbatches`` SGD steps each
    (printing the pass index and the summed loss), then writes
    ``entity2vec_R.txt`` and ``relation2vec_R.txt`` in the working directory.

    Args:
        _: positional argument supplied by ``tf.app.run``; unused.
    """
    lib.init()
    config = Config()
    config.relation = lib.getRelationTotal()
    config.entity = lib.getEntityTotal()
    config.batch_size = lib.getTripleTotal() // config.nbatches

    with tf.Graph().as_default():
        config_gpu = tf.ConfigProto()
        config_gpu.gpu_options.allow_growth = True
        config_gpu.gpu_options.per_process_gpu_memory_fraction = 0.15
        # Entering the session makes it the default (needed by .eval() below)
        # and guarantees it is closed when training finishes or fails.
        with tf.Session(config=config_gpu) as sess:
            initializer = tf.contrib.layers.xavier_initializer(uniform=False)
            with tf.variable_scope("model", reuse=None, initializer=initializer):
                trainModel = TransRModel(config=config)

            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.GradientDescentOptimizer(0.001)
            grads_and_vars = optimizer.compute_gradients(trainModel.loss)
            train_op = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            # initialize_all_variables() is deprecated in favour of this call.
            sess.run(tf.global_variables_initializer())

            def train_step(pos_h_batch, pos_t_batch, pos_r_batch,
                           neg_h_batch, neg_t_batch, neg_r_batch):
                """Run one optimiser step on the given id batches; return the loss."""
                feed_dict = {
                    trainModel.pos_h: pos_h_batch,
                    trainModel.pos_t: pos_t_batch,
                    trainModel.pos_r: pos_r_batch,
                    trainModel.neg_h: neg_h_batch,
                    trainModel.neg_t: neg_t_batch,
                    trainModel.neg_r: neg_r_batch
                }
                _, loss = sess.run([train_op, trainModel.loss], feed_dict)
                return loss

            # Buffers the native sampler fills in place on every getBatch call.
            ph = np.zeros(config.batch_size, dtype=np.int32)
            pt = np.zeros(config.batch_size, dtype=np.int32)
            pr = np.zeros(config.batch_size, dtype=np.int32)
            nh = np.zeros(config.batch_size, dtype=np.int32)
            nt = np.zeros(config.batch_size, dtype=np.int32)
            nr = np.zeros(config.batch_size, dtype=np.int32)

            # Pass real pointer objects: a bare Python int is converted by
            # ctypes to a C int, which truncates 64-bit buffer addresses.
            ph_addr = ph.ctypes.data_as(ctypes.c_void_p)
            pt_addr = pt.ctypes.data_as(ctypes.c_void_p)
            pr_addr = pr.ctypes.data_as(ctypes.c_void_p)
            nh_addr = nh.ctypes.data_as(ctypes.c_void_p)
            nt_addr = nt.ctypes.data_as(ctypes.c_void_p)
            nr_addr = nr.ctypes.data_as(ctypes.c_void_p)

            for times in range(config.trainTimes):
                res = 0.0
                for batch in range(config.nbatches):
                    lib.getBatch(ph_addr, pt_addr, pr_addr, nh_addr,
                                 nt_addr, nr_addr, config.batch_size)
                    res += train_step(ph, pt, pr, nh, nt, nr)
                print(times)
                print(res)

            # Save the learned embeddings as tab-separated text.
            _write_embeddings("entity2vec_R.txt",
                              trainModel.ent_embeddings.eval())
            _write_embeddings("relation2vec_R.txt",
                              trainModel.rel_embeddings.eval())

# Script entry point: hand control to TensorFlow's app runner, which invokes
# main() defined above.
# NOTE(review): the runner ends the process with SystemExit once main()
# returns, which is why an interactive/notebook run reports "SystemExit".
if __name__ == "__main__":
    tf.app.run()


W0707 06:31:04.088330 140141773948736 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0707 06:31:04.227294 140141773948736 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/tf_should_use.py:198: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.


0
152846.43334960938
1
111451.2914428711
2
94623.50207519531
3
83639.50006103516
4
74837.52899169922
5
67176.48510742188
6
60574.228942871094
7
55124.23825073242
8
50494.572265625
9
45890.15505981445
10
42353.011779785156
11
39099.74728393555
12
35681.31671142578
13
33140.969970703125
14
30767.713958740234
15
28955.57553100586
16
27394.999893188477
17
25008.9210357666
18
24299.92742919922
19
22792.769073486328
20
20635.753189086914
21
19749.476455688477
22
18365.64012145996
23
17823.58888244629
24
17330.935134887695
25
15028.080612182617
26
14739.480697631836
27
14822.270599365234
28
14078.522079467773
29
13338.908470153809
30
12047.255882263184
31
11882.05728149414
32
11045.955070495605
33
11547.393310546875
34
10476.954063415527
35
10637.24878692627
36
9864.877983093262
37
10079.456970214844
38
9651.77417755127
39
8797.933807373047
40
9249.086891174316
41
8290.796516418457
42
8434.004013061523
43
7888.321113586426
44
8141.6207275390625
45
7691.645050048828
46
7395.803478240967
47
741

369
1586.8332042694092
370
1559.583384513855
371
1548.5815315246582
372
1620.3629140853882
373
1638.825773715973
374
1577.1077919006348
375
1568.5117120742798
376
1584.1252031326294
377
1560.4522647857666
378
1598.605538368225
379
1597.0581302642822
380
1538.9874897003174
381
1498.9261636734009
382
1589.899673461914
383
1523.458641052246
384
1650.9888734817505
385
1525.8341102600098
386
1520.6756582260132
387
1593.4380836486816
388
1534.0192251205444
389
1487.6833205223083
390
1583.7280769348145
391
1600.0983743667603
392
1503.4591088294983
393
1542.6943292617798
394
1522.7283420562744
395
1522.3437461853027
396
1531.1130986213684
397
1567.939064025879
398
1510.7263774871826
399
1513.5749044418335
400
1549.7379503250122
401
1476.9866652488708
402
1453.393154144287
403
1512.2036666870117
404
1499.3173990249634
405
1531.2271537780762
406
1505.3041219711304
407
1519.7339372634888
408
1504.4515271186829
409
1449.3098945617676
410
1455.0719738006592
411
1462.3265490531921
412
1447.449121952

732
929.9706692695618
733
941.8049039840698
734
956.2523045539856
735
964.1305155754089
736
951.116747379303
737
970.4424622058868
738
969.8927714824677
739
908.8055398464203
740
923.3424937725067
741
958.7469000816345
742
844.68971991539
743
934.8668341636658
744
982.0291905403137
745
933.6894226074219
746
948.8078012466431
747
983.8365197181702
748
952.9836332798004
749
918.7284321784973
750
879.4029247760773
751
908.1159272193909
752
935.9426462650299
753
955.3384485244751
754
926.2899475097656
755
925.0909354686737
756
902.7947380542755
757
865.4047863483429
758
943.7688307762146
759
917.4010903835297
760
884.1050319671631
761
920.0725615024567
762
903.6849255561829
763
896.5852313041687
764
921.2530832290649
765
864.4606137275696
766
879.6069300174713
767
911.0260424613953
768
882.0447161197662
769
896.287544965744
770
886.7879655361176
771
911.4577505588531
772
937.6248860359192
773
912.6373138427734
774
886.7398335933685
775
857.0496799945831
776
905.9841997623444
777
916.212206

1102
679.388299703598
1103
680.8998391628265
1104
690.6231560707092
1105
667.2843134403229
1106
663.1074833869934
1107
710.7529773712158
1108
696.6740326881409
1109
652.7482371330261
1110
673.4921345710754
1111
649.5063753128052
1112
669.5928475856781
1113
664.0601170063019
1114
683.6793513298035
1115
626.889890909195
1116
651.3614675998688
1117
680.5935769081116
1118
662.3854608535767
1119
643.4492492675781
1120
669.4611940383911
1121
726.221387386322
1122
674.0479121208191
1123
638.1029918193817
1124
664.8928561210632
1125
688.6037693023682
1126
683.0784363746643
1127
624.8975958824158
1128
656.016881942749
1129
602.2455370426178
1130
685.116973400116
1131
674.8340537548065
1132
652.780611038208
1133
684.2149803638458
1134
625.3659348487854
1135
672.3050968647003
1136
641.2891039848328
1137
642.9125912189484
1138
662.2505068778992
1139
682.5104796886444
1140
664.5774583816528
1141
643.8059532642365
1142
637.4512741565704
1143
655.359591960907
1144
637.1094205379486
1145
650.695635557

1461
503.99127984046936
1462
536.9914011955261
1463
539.6872897148132
1464
525.9097745418549
1465
502.0338559150696
1466
557.5256993770599
1467
529.1989936828613
1468
523.0525743961334
1469
552.9487628936768
1470
506.0244052410126
1471
542.112592458725
1472
530.4572253227234
1473
537.3942096233368
1474
560.6250383853912
1475
547.0247015953064
1476
527.2435727119446
1477
515.239200592041
1478
545.7151968479156
1479
526.0159320831299
1480
503.22186374664307
1481
547.2240264415741
1482
539.8582410812378
1483
518.6103613376617
1484
503.8643937110901
1485
558.155645608902
1486
521.3367395401001
1487
547.9911584854126
1488
518.0822098255157
1489
505.05189085006714
1490
523.4984636306763
1491
506.0480856895447
1492
502.04845237731934
1493
532.4721505641937
1494
511.55933904647827
1495
497.2016370296478
1496
557.2711887359619
1497
529.9388117790222
1498
505.4675030708313
1499
514.0799021720886
1500
487.35000491142273
1501
525.271416425705
1502
535.9652433395386
1503
510.90174317359924
1504
500

1814
458.1050555706024
1815
432.5638484954834
1816
454.33856987953186
1817
473.59351539611816
1818
474.60285115242004
1819
467.60733437538147
1820
473.3539264202118
1821
427.3813388347626
1822
444.98025155067444
1823
445.28797125816345
1824
462.4619379043579
1825
420.82391703128815
1826
452.43648982048035
1827
457.84636211395264
1828
447.79153966903687
1829
449.1206452846527
1830
452.15911650657654
1831
486.880499958992
1832
462.22623109817505
1833
417.72311305999756
1834
409.5534029006958
1835
425.83845257759094
1836
430.8369565010071
1837
441.69996786117554
1838
427.3367609977722
1839
444.13799226284027
1840
418.265908241272
1841
443.9418611526489
1842
442.913272023201
1843
448.7522597312927
1844
427.66385674476624
1845
419.26442790031433
1846
441.09603476524353
1847
443.5436511039734
1848
449.26511216163635
1849
457.5060420036316
1850
424.1000003814697
1851
452.27210879325867
1852
410.94180166721344
1853
393.0573809146881
1854
441.0133545398712
1855
423.10193860530853
1856
425.49791

2165
385.30276226997375
2166
378.80985713005066
2167
379.6111731529236
2168
349.1780308485031
2169
393.62943267822266
2170
394.7394208908081
2171
386.8988803625107
2172
363.7231867313385
2173
361.874963760376
2174
378.74200081825256
2175
361.95626997947693
2176
360.3348698616028
2177
358.6955726146698
2178
383.61741757392883
2179
363.02609848976135
2180
378.58720684051514
2181
387.29106736183167
2182
402.6522195339203
2183
372.02539253234863
2184
372.144273519516
2185
398.23272466659546
2186
370.22145080566406
2187
374.66964840888977
2188
374.83174204826355
2189
369.345671415329
2190
354.4960563182831
2191
368.1441149711609
2192
361.89513778686523
2193
362.726655125618
2194
391.81089782714844
2195
378.3383882045746
2196
381.9083478450775
2197
369.56525337696075
2198
375.7949450016022
2199
367.80853271484375
2200
387.99761176109314
2201
399.74923956394196
2202
371.1580755710602
2203
379.95675444602966
2204
398.6858870983124
2205
362.94014596939087
2206
379.0633696317673
2207
375.8084659

2516
326.4294419288635
2517
325.7963333129883
2518
327.0170907974243
2519
350.00822603702545
2520
313.5047627687454
2521
350.64223635196686
2522
331.18041491508484
2523
358.61805963516235
2524
319.6674828529358
2525
332.5332601070404
2526
319.93474531173706
2527
330.5959527492523
2528
326.47203636169434
2529
320.29482436180115
2530
337.50591468811035
2531
320.42144775390625
2532
340.570219039917
2533
338.8854316473007
2534
314.2381112575531
2535
341.2932940721512
2536
336.68772292137146
2537
326.35246443748474
2538
326.2900433540344
2539
322.0106110572815
2540
343.40606689453125
2541
331.8502435684204
2542
346.77618420124054
2543
345.7140169143677
2544
324.21297788619995
2545
344.6523187160492
2546
332.50428438186646
2547
332.2785439491272
2548
326.87750816345215
2549
349.40299701690674
2550
345.7994706630707
2551
319.23266649246216
2552
332.91435492038727
2553
352.7888355255127
2554
325.70838379859924
2555
328.3249137401581
2556
313.39913749694824
2557
318.3709490299225
2558
326.78945

2866
310.6072951555252
2867
310.0711649656296
2868
303.84035301208496
2869
290.92761278152466
2870
308.4215614795685
2871
297.96153235435486
2872
270.79011631011963
2873
297.50292241573334
2874
291.23371386528015
2875
310.1634554862976
2876
304.8243713378906
2877
305.99075651168823
2878
305.2828356027603
2879
314.2230694293976
2880
315.12665843963623
2881
346.2891185283661
2882
298.7719646692276
2883
295.8814845085144
2884
321.72625732421875
2885
284.886944770813
2886
320.91077959537506
2887
304.1366398334503
2888
288.3958303928375
2889
291.0185852050781
2890
310.3022015094757
2891
283.71202313899994
2892
282.00420594215393
2893
327.93822026252747
2894
320.28259205818176
2895
298.22482442855835
2896
286.46455001831055
2897
304.9577796459198
2898
307.32223534584045
2899
319.23275780677795
2900
277.2018382549286
2901
300.46785163879395
2902
297.5386265516281
2903
305.36464738845825
2904
291.39654886722565
2905
289.4493398666382
2906
301.8134653568268
2907
288.8937042951584
2908
318.44903

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
