In [6]:
# coding:utf-8
import numpy as np
import tensorflow as tf
import os
import time
import datetime
import ctypes

# Short alias for the ctypes shared-library loader.
ll = ctypes.cdll.LoadLibrary
# Native helper library, resolved relative to the current working directory.
# main() calls lib.init(), lib.getRelationTotal(), lib.getEntityTotal(),
# lib.getTripleTotal() and lib.getBatch() on this handle.
lib = ll("./init.so")


class Config(object):
    """Hyper-parameters and dataset sizes used to build and train the model.

    ``entity``, ``relation`` and ``batch_size`` start at 0 as placeholders;
    ``main`` overwrites them with the values reported by the native library
    before the model is constructed.
    """

    def __init__(self):
        # True selects the L1 norm for the triple distance, False the L2 norm.
        self.L1_flag = True
        # Width of every entity / relation embedding vector.
        self.hidden_size = 50
        # Number of mini-batches drawn per training pass.
        self.nbatches = 100
        # Number of entities / relations (rows of the embedding tables).
        self.entity = 0
        self.relation = 0
        # Triples per mini-batch. Previously this attribute only existed after
        # main() attached it, so a freshly built Config was incomplete.
        self.batch_size = 0
        # Number of training passes (each pass runs ``nbatches`` batches).
        self.trainTimes = 3000
        # Margin of the hinge ranking loss.
        self.margin = 1.0


class TransEModel(object):
    """Graph definition of a margin-based translation embedding model.

    Building an instance adds to the current default TensorFlow graph:

    * six int32 placeholders of shape ``[None]`` (``pos_h``, ``pos_t``,
      ``pos_r``, ``neg_h``, ``neg_t``, ``neg_r``) holding head / tail /
      relation ids of the positive and the negative triples;
    * two embedding tables, ``ent_embeddings`` of shape
      ``[config.entity, config.hidden_size]`` and ``rel_embeddings`` of shape
      ``[config.relation, config.hidden_size]``;
    * ``loss``, the sum over the batch of
      ``max(0, margin + ||h + r - t||_pos - ||h + r - t||_neg)``.

    Args:
        config: object exposing ``entity``, ``relation``, ``hidden_size``,
            ``margin`` and ``L1_flag``.
    """

    def __init__(self, config):

        entity_total = config.entity
        relation_total = config.relation
        size = config.hidden_size
        margin = config.margin

        self.pos_h = tf.placeholder(tf.int32, [None])
        self.pos_t = tf.placeholder(tf.int32, [None])
        self.pos_r = tf.placeholder(tf.int32, [None])

        self.neg_h = tf.placeholder(tf.int32, [None])
        self.neg_t = tf.placeholder(tf.int32, [None])
        self.neg_r = tf.placeholder(tf.int32, [None])

        with tf.name_scope("embedding"):
            self.ent_embeddings = tf.get_variable(
                name="ent_embedding",
                shape=[entity_total, size],
                initializer=tf.contrib.layers.xavier_initializer(uniform=False))
            self.rel_embeddings = tf.get_variable(
                name="rel_embedding",
                shape=[relation_total, size],
                initializer=tf.contrib.layers.xavier_initializer(uniform=False))
            pos_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_h)
            pos_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.pos_t)
            pos_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.pos_r)
            neg_h_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_h)
            # These two lookups were missing, which made the distance
            # computation below fail with NameError.
            neg_t_e = tf.nn.embedding_lookup(self.ent_embeddings, self.neg_t)
            neg_r_e = tf.nn.embedding_lookup(self.rel_embeddings, self.neg_r)

        # Per-triple distance ||h + r - t|| reduced over the embedding axis,
        # using the L1 norm when config.L1_flag is set and the L2 norm otherwise.
        norm_order = 1 if config.L1_flag else 2
        pos_e = tf.norm(pos_h_e + pos_r_e - pos_t_e, ord=norm_order, axis=1)
        neg_e = tf.norm(neg_h_e + neg_r_e - neg_t_e, ord=norm_order, axis=1)

        with tf.name_scope("output"):
            # Hinge ranking loss: a positive triple should be closer than its
            # negative counterpart by at least `margin`.
            self.loss = tf.reduce_sum(
                tf.math.maximum(0.0, margin + pos_e - neg_e))
            
            
def _write_embeddings(path, matrix):
    """Write a 2-D array to ``path`` as text, one row per line.

    Every value is formatted with ``%f`` and followed by a tab, so each line
    ends with a trailing tab before the newline (same layout as before).
    """
    with open(path, "w") as out:
        for row in matrix:
            out.write("".join("%f\t" % value for value in row))
            out.write("\n")


def main(_):
    """Train the model on batches from the native library and dump embeddings.

    Steps: initialise the native library, read dataset sizes into a Config,
    build the model and an SGD training op, run ``config.trainTimes`` passes
    of ``config.nbatches`` batches each (printing the pass index and the
    summed loss), then write the learned tables to ``entity2vec.txt`` and
    ``relation2vec.txt`` in the current working directory.

    Args:
        _: unused; present because ``tf.app.run`` passes argv to ``main``.
    """
    lib.init()
    config = Config()
    config.relation = lib.getRelationTotal()
    config.entity = lib.getEntityTotal()
    config.batch_size = lib.getTripleTotal() // config.nbatches

    with tf.Graph().as_default():
        config_gpu = tf.ConfigProto()
        config_gpu.gpu_options.allow_growth = True
        config_gpu.gpu_options.per_process_gpu_memory_fraction = 0.15
        # Using the session as a context manager installs it as the default
        # session (needed by .eval() below) and closes it on exit.
        with tf.Session(config=config_gpu) as sess:
            initializer = tf.contrib.layers.xavier_initializer(uniform=False)
            with tf.variable_scope("model", reuse=None, initializer=initializer):
                trainModel = TransEModel(config=config)

            global_step = tf.Variable(0, name="global_step", trainable=False)
            optimizer = tf.train.GradientDescentOptimizer(0.001)
            grads_and_vars = optimizer.compute_gradients(trainModel.loss)
            train_op = optimizer.apply_gradients(
                grads_and_vars, global_step=global_step)
            # Checkpointing is currently disabled; call
            # saver.save(sess, 'model.vec') after training to enable it.
            saver = tf.train.Saver()
            # tf.initialize_all_variables() is deprecated in favour of this.
            sess.run(tf.global_variables_initializer())

            def train_step(pos_h_batch, pos_t_batch, pos_r_batch, neg_h_batch, neg_t_batch, neg_r_batch):
                """Run one optimisation step and return the batch loss."""
                feed_dict = {
                    trainModel.pos_h: pos_h_batch,
                    trainModel.pos_t: pos_t_batch,
                    trainModel.pos_r: pos_r_batch,
                    trainModel.neg_h: neg_h_batch,
                    trainModel.neg_t: neg_t_batch,
                    trainModel.neg_r: neg_r_batch
                }
                _, loss = sess.run([train_op, trainModel.loss], feed_dict)
                return loss

            # Reusable int32 buffers that lib.getBatch is handed on every call.
            ph = np.zeros(config.batch_size, dtype=np.int32)
            pt = np.zeros(config.batch_size, dtype=np.int32)
            pr = np.zeros(config.batch_size, dtype=np.int32)
            nh = np.zeros(config.batch_size, dtype=np.int32)
            nt = np.zeros(config.batch_size, dtype=np.int32)
            nr = np.zeros(config.batch_size, dtype=np.int32)

            # Pass the buffers as real pointer objects. A bare Python int is
            # converted by ctypes as a C int, which can truncate a 64-bit
            # address.
            ph_addr = ph.ctypes.data_as(ctypes.c_void_p)
            pt_addr = pt.ctypes.data_as(ctypes.c_void_p)
            pr_addr = pr.ctypes.data_as(ctypes.c_void_p)
            nh_addr = nh.ctypes.data_as(ctypes.c_void_p)
            nt_addr = nt.ctypes.data_as(ctypes.c_void_p)
            nr_addr = nr.ctypes.data_as(ctypes.c_void_p)

            for times in range(config.trainTimes):
                res = 0.0
                for batch in range(config.nbatches):
                    lib.getBatch(ph_addr, pt_addr, pr_addr, nh_addr,
                                 nt_addr, nr_addr, config.batch_size)
                    res += train_step(ph, pt, pr, nh, nt, nr)
                print(times)
                print(res)

            # Save the learned embeddings.
            _write_embeddings("entity2vec.txt",
                              trainModel.ent_embeddings.eval())
            _write_embeddings("relation2vec.txt",
                              trainModel.rel_embeddings.eval())

# Script entry point: hand control to TensorFlow's app runner, which invokes
# the module-level main(). NOTE(review): the captured notebook output ends with
# a SystemExit notice, which suggests tf.app.run() exits the interpreter after
# main() returns - confirm before relying on code placed after this call.
if __name__ == "__main__":
    tf.app.run()


0
152832.54516601562
1
111016.54931640625
2
94944.21826171875
3
84214.93981933594
4
76162.6323852539
5
68993.54681396484
6
62727.624450683594
7
57265.40603637695
8
52777.14208984375
9
48085.99621582031
10
43849.87612915039
11
39964.414794921875
12
37176.940368652344
13
33594.669036865234
14
31702.248168945312
15
30151.248046875
16
27884.185607910156
17
25892.03370666504
18
23620.275756835938
19
22759.00291442871
20
21296.858917236328
21
20345.960159301758
22
19047.152420043945
23
18236.964645385742
24
16818.522354125977
25
15809.563789367676
26
15183.274932861328
27
13995.792289733887
28
13809.455947875977
29
13261.789054870605
30
12879.124778747559
31
12668.335655212402
32
11747.771591186523
33
11927.75439453125
34
11255.27239227295
35
10538.22550201416
36
10268.011024475098
37
9871.856704711914
38
9351.706787109375
39
8842.538131713867
40
8983.684844970703
41
8628.940719604492
42
8863.996891021729
43
8039.460952758789
44
7983.9202880859375
45
8479.92322921753
46
7731.883056640625
47


370
1716.072434425354
371
1725.705750465393
372
1715.8866300582886
373
1639.6315245628357
374
1589.3677797317505
375
1680.27845287323
376
1720.8303365707397
377
1598.5370831489563
378
1604.2494945526123
379
1640.277376651764
380
1599.9993624687195
381
1598.4267110824585
382
1649.3431797027588
383
1629.530758857727
384
1608.0235567092896
385
1566.4620475769043
386
1653.8294162750244
387
1610.0934448242188
388
1624.5157570838928
389
1564.5720524787903
390
1577.929202079773
391
1617.6620073318481
392
1507.2616987228394
393
1617.0164461135864
394
1573.3209981918335
395
1568.71102809906
396
1557.3366260528564
397
1591.8749284744263
398
1567.8344330787659
399
1580.187201499939
400
1547.0074462890625
401
1593.8403987884521
402
1579.4148454666138
403
1576.586365222931
404
1605.1101379394531
405
1598.933801651001
406
1569.3021836280823
407
1592.1273398399353
408
1558.6355481147766
409
1553.5723390579224
410
1505.4446954727173
411
1555.2584409713745
412
1551.917338848114
413
1556.291633605957
41

732
986.9138798713684
733
948.9096598625183
734
1004.750123500824
735
938.1286337375641
736
919.9436721801758
737
1006.1320390701294
738
966.853945016861
739
999.2246074676514
740
941.547486782074
741
997.9053258895874
742
934.1119334697723
743
997.9422945976257
744
938.4624743461609
745
941.6578183174133
746
991.7857041358948
747
913.2446253299713
748
966.8487277030945
749
902.6456179618835
750
967.2096571922302
751
898.4560186862946
752
956.1159462928772
753
895.8729212284088
754
986.6867680549622
755
934.5539929866791
756
917.0196974277496
757
928.1924710273743
758
944.178936958313
759
898.0931260585785
760
985.8273615837097
761
935.7124032974243
762
908.0339312553406
763
922.0408682823181
764
956.3365869522095
765
917.335330247879
766
917.9143044948578
767
962.9656507968903
768
895.6851131916046
769
948.3434076309204
770
873.2810168266296
771
929.1003305912018
772
912.8669793605804
773
936.9248778820038
774
949.5708899497986
775
944.1617398262024
776
938.7809915542603
777
879.47660

1103
704.3520715236664
1104
659.9465274810791
1105
682.0750217437744
1106
657.3648343086243
1107
649.5806927680969
1108
678.8612189292908
1109
687.582781791687
1110
715.8394088745117
1111
702.964652299881
1112
691.7089602947235
1113
670.5263526439667
1114
695.980304479599
1115
620.3411548137665
1116
674.7077441215515
1117
680.1189076900482
1118
659.8161022663116
1119
676.4764358997345
1120
633.5557100772858
1121
687.3416278362274
1122
681.045871257782
1123
646.0316569805145
1124
683.2949407100677
1125
644.1334891319275
1126
697.2370092868805
1127
717.9676430225372
1128
621.302973985672
1129
669.0671918392181
1130
698.8348400592804
1131
661.5217754840851
1132
661.1929175853729
1133
672.9327380657196
1134
677.9727208614349
1135
692.2257778644562
1136
653.3470914363861
1137
607.0957789421082
1138
665.4294376373291
1139
663.8593188524246
1140
658.651852607727
1141
670.6793682575226
1142
659.8571383953094
1143
672.2710573673248
1144
624.9461116790771
1145
648.8009045124054
1146
663.72155976

1461
532.0756847858429
1462
510.48772871494293
1463
543.66392827034
1464
543.4930386543274
1465
522.0389087200165
1466
520.5607576370239
1467
536.5313239097595
1468
543.1850790977478
1469
564.2598340511322
1470
536.4216115474701
1471
495.4503695964813
1472
557.9102435112
1473
513.0160639286041
1474
560.2981860637665
1475
491.44689679145813
1476
527.3812980651855
1477
557.3127913475037
1478
530.9866209030151
1479
523.7974965572357
1480
488.4669382572174
1481
533.8396233320236
1482
541.2996499538422
1483
532.0841139554977
1484
537.5927903652191
1485
490.1772359609604
1486
525.5566198825836
1487
508.11076951026917
1488
525.0831602811813
1489
534.3485445976257
1490
520.4320755004883
1491
549.943033695221
1492
567.2292139530182
1493
542.850795507431
1494
534.5159027576447
1495
519.5873763561249
1496
538.455260515213
1497
536.824031829834
1498
538.1660966873169
1499
511.84038281440735
1500
558.1594967842102
1501
526.1148390769958
1502
541.3979252576828
1503
520.6142203807831
1504
544.1252372

1815
483.3136639595032
1816
461.2183768749237
1817
438.5616924762726
1818
458.5146806240082
1819
435.4230670928955
1820
408.00919806957245
1821
418.84295654296875
1822
436.88251185417175
1823
419.9050979614258
1824
431.6488275527954
1825
390.63740360736847
1826
428.0263057947159
1827
452.6233766078949
1828
466.5500829219818
1829
435.2249287366867
1830
439.6249040365219
1831
417.78439915180206
1832
423.2965922355652
1833
443.13307762145996
1834
412.27057003974915
1835
454.2572581768036
1836
435.4132122993469
1837
436.35832464694977
1838
445.0008702278137
1839
432.7276268005371
1840
446.00888180732727
1841
453.6281976699829
1842
476.04029273986816
1843
410.3186180591583
1844
427.0834159851074
1845
415.0826361179352
1846
448.8005862236023
1847
410.9365712404251
1848
425.0567526817322
1849
418.55329036712646
1850
431.49317383766174
1851
416.3522832393646
1852
467.7189917564392
1853
440.50202548503876
1854
411.2214734554291
1855
479.3506808280945
1856
438.9662160873413
1857
455.859970808029

2166
374.5798661708832
2167
382.92096757888794
2168
392.8405315876007
2169
381.344708442688
2170
381.3363347053528
2171
390.3847665786743
2172
365.5834991931915
2173
364.4446518421173
2174
370.37565994262695
2175
369.3776453733444
2176
406.76433396339417
2177
394.92744731903076
2178
376.6397898197174
2179
366.4682674407959
2180
377.6661524772644
2181
385.8137540817261
2182
398.50591015815735
2183
382.1039967536926
2184
346.12546968460083
2185
356.8486531972885
2186
371.0950927734375
2187
343.1024520397186
2188
373.76882457733154
2189
427.526971578598
2190
387.5142240524292
2191
387.5361545085907
2192
386.30883383750916
2193
389.3003406524658
2194
380.25556230545044
2195
371.0393490791321
2196
416.252637386322
2197
374.57933259010315
2198
368.2151310443878
2199
415.95910263061523
2200
365.4691708087921
2201
359.37506437301636
2202
367.9313642978668
2203
351.59616136550903
2204
381.420218706131
2205
381.00665831565857
2206
344.5794560909271
2207
365.3184039592743
2208
340.18110275268555


2517
330.18787121772766
2518
343.0966491699219
2519
332.098868727684
2520
338.61309134960175
2521
328.2711635828018
2522
322.8436839580536
2523
355.76059663295746
2524
342.89653193950653
2525
340.4399914741516
2526
314.3860778808594
2527
352.6178261041641
2528
327.5290699005127
2529
322.2153648138046
2530
316.9627695083618
2531
297.9218165874481
2532
301.57122004032135
2533
325.0916097164154
2534
315.0316421985626
2535
344.85397696495056
2536
332.6824929714203
2537
340.8425347805023
2538
319.6492967605591
2539
358.10878896713257
2540
339.70447158813477
2541
311.2888433933258
2542
312.02228331565857
2543
325.96773886680603
2544
349.46124482154846
2545
301.0046122074127
2546
315.5485484600067
2547
357.44156098365784
2548
320.04020833969116
2549
347.32751870155334
2550
308.19398760795593
2551
296.6023621559143
2552
310.56139266490936
2553
324.10522508621216
2554
350.0108230113983
2555
338.60780477523804
2556
355.34197652339935
2557
352.16967487335205
2558
311.1917173862457
2559
352.399968

2868
311.77322244644165
2869
298.73140954971313
2870
314.78833985328674
2871
302.8057942390442
2872
271.68712544441223
2873
284.74906826019287
2874
304.9540106058121
2875
297.8690232038498
2876
288.7253794670105
2877
294.8435287475586
2878
324.1924293041229
2879
312.38576793670654
2880
279.0812964439392
2881
300.4603922367096
2882
292.17583072185516
2883
300.0403935909271
2884
291.70130467414856
2885
304.3513948917389
2886
320.3005335330963
2887
303.6813452243805
2888
266.8559910058975
2889
285.436688542366
2890
304.1631338596344
2891
298.38514518737793
2892
273.4751808643341
2893
303.12067675590515
2894
291.1613256931305
2895
302.28665351867676
2896
316.56826543807983
2897
291.3100210428238
2898
294.6157076358795
2899
286.45563554763794
2900
314.6660438776016
2901
310.57271909713745
2902
291.36342215538025
2903
296.6475855112076
2904
315.04341220855713
2905
297.4324219226837
2906
287.92250394821167
2907
286.3850429058075
2908
290.0815634727478
2909
294.081827878952
2910
281.7158931493

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
