In [2]:
import numpy as np
import tensorflow as tf
import json
import os
import pickle
import sys
# Make the project root (two levels above the notebook's working directory) and its
# `path_rnn` package directory importable. Each entry is checked on its own so that a
# pre-existing `nb_dir` entry (e.g. from PYTHONPATH or an earlier run) does not prevent
# `path_rnn` from being added.
nb_dir = os.path.dirname(os.path.dirname(os.getcwd()))
for extra_path in (nb_dir, os.path.join(nb_dir, 'path_rnn')):
    if extra_path not in sys.path:
        sys.path.append(extra_path)
import pandas as pd
from parsing.special_tokens import *
from path_rnn.embeddings import load_embedding_matrix
from path_rnn.batch_generator import BatchGenerator
from util import get_target_relations_vocab
from tensor_generator import get_indexed_paths, get_labels
from path_rnn.embeddings import Word2VecEmbeddings, RandomEmbeddings, EntityTypeEmbeddings
# Show full cell contents (the path columns are long lists). `None` is the supported way
# to disable truncation; the former `-1` is deprecated since pandas 1.0 and rejected by
# newer releases. Fall back to `-1` only for old pandas versions that require an int.
try:
    pd.set_option('display.max_colwidth', None)
except ValueError:
    pd.set_option('display.max_colwidth', -1)

In [3]:
# Load the extracted training paths table.
train_dataset = pd.read_json(
        '../../data/sentwise=F_cutoff=4_limit=100_method=shortest_tokenizer=punkt_medhop_train.json')
# Load the pickled document store through a context manager so the file handle is
# closed as soon as loading finishes (the bare `open(...)` leaked it).
# NOTE: unpickling can execute arbitrary code — only load pickles you produced yourself.
with open('../../data/train_doc_store_punkt.pickle', 'rb') as doc_store_file:
    train_document_store = pickle.load(doc_store_file)

In [7]:
# Clear the default TF graph before (re)building the embeddings in this cell.
tf.reset_default_graph()


def _unit_truncated_normal():
    """Return a fresh float64 truncated-normal initializer (mean 0.0, stddev 1.0)."""
    return tf.truncated_normal_initializer(mean=0.0, stddev=1.0, dtype=tf.float64)


# Every special token is paired with False except PAD, which is paired with True.
# NOTE(review): the meaning of the boolean is defined by Word2VecEmbeddings — confirm there.
_unflagged_tokens = (ENT_1, ENT_2, ENT_X, UNK, END)
relation_special_tokens = [(token, False) for token in _unflagged_tokens] + [(PAD, True)]

relation_token_embeddings = Word2VecEmbeddings(
    './../../medhop_word2vec_punkt',
    name='token_embd',
    unk_token=UNK,
    trainable=False,
    special_tokens=relation_special_tokens)

# Entities reuse the very same embedding object as relation tokens.
entity_embeddings = relation_token_embeddings

target_embeddings = RandomEmbeddings(
    train_dataset['relation'],
    name='target_rel_emb',
    embedding_size=150,
    unk_token=None,
    initializer=_unit_truncated_normal())

entity_type_embeddings = EntityTypeEmbeddings(
    './../../parsing/entity_map.txt',
    name='entity_type_emb',
    embedding_size=150,
    unk_token=None,
    initializer=_unit_truncated_normal())

In [12]:
# Entity-type indices for one `DB…` identifier and one `P…` identifier.
sample_entity_ids = ('DB00294', 'P35367')
print(*[entity_type_embeddings.get_idx(entity_id) for entity_id in sample_entity_ids])

0 1


In [17]:
# Collect the indexing inputs first so the call itself stays readable.
path_indexing_args = dict(
    q_relation_paths=train_dataset['relation_paths'],
    q_entity_paths=train_dataset['entity_paths'],
    target_relations=train_dataset['relation'],
    document_store=train_document_store,
    relation_token_embeddings=relation_token_embeddings,
    entity_embeddings=entity_embeddings,
    target_relation_embeddings=target_embeddings,
    max_path_length=5,
    max_relation_length=train_document_store.max_tokens,
    truncate_doc=True,
)

# Turn the textual paths into index arrays; the helper returns six aligned results.
(indexed_relation_paths,
 indexed_entity_paths,
 indexed_target_relations,
 path_partitions,
 path_lengths,
 num_words) = get_indexed_paths(**path_indexing_args)

Total paths: 33370


In [18]:
# Build the label input from the dataset's `label` column; it is handed to
# BatchGenerator further down.
# NOTE(review): the exact shape/dtype returned by get_labels is not visible here — confirm.
labels = get_labels(train_dataset['label'])

In [19]:
# First ten [start, end] path spans. In the recorded output each span's width equals the
# number of paths of the matching train_dataset row (e.g. row 1 has 3 paths -> [1, 4]),
# and every span starts where the previous one ends.
path_partitions[:10]

array([[ 0,  1],
       [ 1,  4],
       [ 4,  5],
       [ 5, 11],
       [11, 17],
       [17, 19],
       [19, 20],
       [20, 23],
       [23, 26],
       [26, 28]])

In [20]:
# Lengths of the first ten flattened paths (stored as floats). In the recorded output
# they match the number of entities per path: 3 for the single path of row 0, 5 for the
# three paths of row 1, 4 for row 2, ...
path_lengths[:10]

array([3., 5., 5., 5., 4., 5., 5., 5., 5., 5.])

In [21]:
# Raw rows behind the first ten partitions, used to cross-check the indexed arrays by eye.
train_dataset[:10]

Unnamed: 0,id,source,target,relation,entity_paths,relation_paths,label
0,MH_train_0,DB00072,DB00773,interacts_with,"[[DB00072, P16104, DB00773]]","[[2822, 13427, -1]]",1
1,MH_train_0,DB00294,DB00773,interacts_with,"[[DB00294, P06401, P12830, P16104, DB00773], [DB00294, P06401, P12830, P15692, DB00773], [DB00294, P06401, P12830, Q13315, DB00773]]","[[12161, 5896, 11862, 13427, -1], [12161, 5896, 3003, 8080, -1], [12161, 5896, 12119, 537, -1]]",0
2,MH_train_0,DB00338,DB00773,interacts_with,"[[DB00338, O75030, DB00133, DB00773]]","[[3059, 11424, 537, -1]]",0
3,MH_train_0,DB00341,DB00773,interacts_with,"[[DB00341, P35367, P05231, P16104, DB00773], [DB00341, P35367, P05231, P15692, DB00773], [DB00341, P35367, P01375, P15692, DB00773], [DB00341, P35367, P05231, P35228, DB00773], [DB00341, P35367, P05231, Q13315, DB00773], [DB00341, P35367, P01375, P43351, DB00773]]","[[9869, 7118, 11664, 13427, -1], [9869, 7118, 11285, 8080, -1], [9869, 7118, 2873, 8080, -1], [9869, 7118, 11664, 13427, -1], [9869, 7118, 11664, 537, -1], [9869, 7118, 2438, 848, -1]]",0
4,MH_train_0,DB00588,DB00773,interacts_with,"[[DB00588, P05113, P15692, DB00773], [DB00588, P05112, P15692, DB00773], [DB00588, P01375, P15692, DB00773], [DB00588, P05112, P16104, DB00773], [DB00588, P05112, DB00133, DB00773], [DB00588, P01375, P43351, DB00773]]","[[6139, 11285, 8080, -1], [3161, 11285, 8080, -1], [3161, 2873, 8080, -1], [3161, 347, 13427, -1], [3161, 347, 537, -1], [3161, 2438, 848, -1]]",0
5,MH_train_0,DB00820,DB00773,interacts_with,"[[DB00820, P21554, P15692, DB00773], [DB00820, P29474, P16104, DB00773]]","[[1953, 1512, 8080, -1], [1003, 43, 13427, -1]]",0
6,MH_train_0,DB02546,DB00773,interacts_with,"[[DB02546, P16104, DB00773]]","[[11837, 13427, -1]]",0
7,MH_train_0,DB02901,DB00773,interacts_with,"[[DB02901, P10275, P16104, DB00773], [DB02901, P10275, P15692, DB00773], [DB02901, P05093, P43351, DB00773]]","[[1803, 9456, 13427, -1], [1803, 9454, 8080, -1], [1803, 2438, 848, -1]]",0
8,MH_train_0,DB04844,DB00773,interacts_with,"[[DB04844, Q05940, P01308, P43351, DB00773], [DB04844, Q05940, P01308, P15692, DB00773], [DB04844, Q05940, P37840, P16104, DB00773]]","[[12088, 8550, 6964, 848, -1], [12088, 8550, 11276, 8080, -1], [12088, 9915, 10620, 13427, -1]]",0
9,MH_train_1,DB06822,DB09079,interacts_with,"[[DB06822, P48061, P15692, DB09079], [DB06822, P61073, P15692, DB09079]]","[[8149, 8092, 4068, -1], [8149, 8092, 4068, -1]]",1


In [22]:
# One length per flattened path: the recorded 33370 equals the "Total paths" count
# printed by get_indexed_paths.
path_lengths.shape

(33370,)

In [23]:
# Same peek as before (first ten path lengths); kept next to num_words for comparison.
path_lengths[:10]

array([3., 5., 5., 5., 4., 5., 5., 5., 5., 5.])

In [24]:
# Per-step token counts for the first path (five entries, matching max_path_length=5).
# NOTE(review): presumably the first two entries are the two relation documents and the
# trailing 1s are the #END step plus padding steps — confirm against get_indexed_paths.
num_words[0]

array([231, 151,   1,   1,   1])

In [25]:
# Raw token count of document 13427, the second relation of the first path. The recorded
# 361 is larger than num_words[0][1] == 151.
# NOTE(review): presumably a consequence of truncate_doc=True — confirm.
len(train_document_store.documents[13427])

361

In [26]:
# Token indices of the first relation of the first path. The leading 26450, 66, 6048
# correspond to 'ent_1', 'has', 'opposing' (checked with get_idx in later cells).
indexed_relation_paths[0][0][:260]

array([26450,    66,  6048,    57,    26,  3679,     2,  1754,     2,
          31,  1001,     2,  2679,  3635,     5,  1104,     6, 26452,
           2,   225,   691,    46,    16,  3196,    26,   322,   745,
           0,  2401,    23,    49,   194,     4,    57,     3,   897,
           1,    30,    93,     2, 26452,  2020,   445,   223,     1,
          26,   150,  3635,    31,    12,  3679,     2,  1754,     1,
          10,  4303,     2,    21,    75,     1,     6,   691,    46,
          24,   215,   225,    28,   392,    15, 26452,    22,     0,
        2235,   640,   214,    23,  4528,     2, 14271,     7, 26452,
         138,     8,     5, 26452,     7, 26452,     2,     8,    16,
          19,   898,     9,  3679,     2,  1754,     6,     4,   266,
          28,   682,     3,   897,     0, 26450,    14,  1863,   325,
        1296,     9,    28,    79,  3679,     2,  1754,     0,  1654,
           3,   897,    26,     4,   244,     3,   502,     2, 26451,
           1,    10,

In [27]:
# Third step of the first path: a single 26454 ('#END' per get_idx below) followed by
# 26455 fill values.
# NOTE(review): 26455 is presumably the PAD index — confirm.
indexed_relation_paths[0][2][:10]

array([26454, 26455, 26455, 26455, 26455, 26455, 26455, 26455, 26455,
       26455])

In [28]:
# Tokenised text of document 2822 for the entity pair (DB00072, P16104); in the recorded
# output entity mentions appear as the placeholders 'ent_1' / 'ent_x'.
train_document_store.get_document(2822, 'DB00072', 'P16104')

['ent_1',
 'has',
 'opposing',
 'effects',
 'on',
 'SN',
 '-',
 '38',
 '-',
 'induced',
 'double',
 '-',
 'strand',
 'breaks',
 'and',
 'cytotoxicity',
 'in',
 'ent_x',
 '-',
 'positive',
 'gastric',
 'cancer',
 'cells',
 'depending',
 'on',
 'administration',
 'sequence',
 '.',
 'AIM',
 ':',
 'We',
 'investigated',
 'the',
 'effects',
 'of',
 'trastuzumab',
 ',',
 'an',
 'anti',
 '-',
 'ent_x',
 'humanized',
 'monoclonal',
 'antibody',
 ',',
 'on',
 'DNA',
 'breaks',
 'induced',
 'by',
 'SN',
 '-',
 '38',
 ',',
 'a',
 'topoisomerase',
 '-',
 '1',
 'inhibitor',
 ',',
 'in',
 'gastric',
 'cancer',
 'cell',
 'lines',
 'positive',
 'or',
 'negative',
 'for',
 'ent_x',
 'expression',
 '.',
 'MATERIALS',
 'AND',
 'METHODS',
 ':',
 'NCI',
 '-',
 'N87',
 '(',
 'ent_x',
 '+',
 ')',
 'and',
 'ent_x',
 '(',
 'ent_x',
 '-',
 ')',
 'cells',
 'were',
 'exposed',
 'to',
 'SN',
 '-',
 '38',
 'in',
 'the',
 'presence',
 'or',
 'absence',
 'of',
 'trastuzumab',
 '.',
 'ent_1',
 'was',
 'added',
 'eithe

In [29]:
# Vocabulary index of the 'ent_1' placeholder; matches the first id of
# indexed_relation_paths[0][0] (26450).
relation_token_embeddings.get_idx('ent_1')

26450

In [30]:
# Vocabulary index of 'has'; matches the second id of indexed_relation_paths[0][0] (66).
relation_token_embeddings.get_idx('has')

66

In [31]:
# Vocabulary index of 'opposing'; matches the third id of indexed_relation_paths[0][0] (6048).
relation_token_embeddings.get_idx('opposing')

6048

In [32]:
# Vocabulary index of the '#END' token; matches the first id of
# indexed_relation_paths[0][2] (26454).
relation_token_embeddings.get_idx('#END')

26454

In [33]:
# Same peek as before (first ten path lengths); compare with the fill pattern of
# indexed_entity_paths in the next cell.
path_lengths[:10]

array([3., 5., 5., 5., 4., 5., 5., 5., 5., 5.])

In [34]:
# Entity indices of the first ten paths, five slots per path. Row 0 holds the ids of
# DB00072, P16104, DB00773 (see the get_idx check in the next cell) followed by 26455 fill.
# NOTE(review): the leading 26453 in later rows is presumably the UNK index — confirm.
indexed_entity_paths[:10]

array([[ 3104,  2251, 12301, 26455, 26455],
       [26453,  5438,  1274,  2251, 12301],
       [26453,  5438,  1274,   182, 12301],
       [26453,  5438,  1274,  1773, 12301],
       [25028,  5679,  1838, 12301, 26455],
       [26453,  2213,   299,  2251, 12301],
       [26453,  2213,   299,   182, 12301],
       [26453,  2213,    86,   182, 12301],
       [26453,  2213,   299,  1131, 12301],
       [26453,  2213,   299,  1773, 12301]])

In [36]:
# Indices of the three entities on the first path; they should equal the first three
# entries of indexed_entity_paths[0].
first_path_entities = ('DB00072', 'P16104', 'DB00773')
print(*[relation_token_embeddings.get_idx(entity) for entity in first_path_entities])

3104 2251 12301


## Batch generator validation

In [49]:
# Positional inputs of the generator, in the order BatchGenerator expects them.
batch_inputs = (
    indexed_relation_paths,
    indexed_entity_paths,
    indexed_target_relations,
    path_partitions,
    path_lengths,
    num_words,
    labels,
)

# Small batches plus 10% test / 10% train-eval proportions; the recorded output shows
# the resulting split sizes.
batch_generator = BatchGenerator(*batch_inputs,
                                 batch_size=5,
                                 test_prop=0.1,
                                 train_eval_prop=0.1)

Dataset size: 14436
Train queries: 12992 in 2599 batches,
Test queries: 1444 in 1 batches,
Train evaluation queries: 1299 in 1 batches



In [50]:
# Query indices held by the TRAIN_EVAL index generator.
# NOTE(review): all recorded values are below train_size (12992), so this presumably is a
# sample of the training queries — confirm in BatchGenerator.
batch_generator.idx_generators[BatchGenerator.TRAIN_EVAL].idxs

array([ 5313, 10613,  1359, ..., 11203,  1002, 10704])

In [51]:
# Should equal the "Train queries" figure printed when the generator was built (12992).
batch_generator.train_size

12992

In [52]:
# Path spans of the TRAIN_EVAL queries, still expressed as offsets into the full path arrays.
train_eval_query_idxs = batch_generator.idx_generators[BatchGenerator.TRAIN_EVAL].idxs
batch_generator.path_partitions[train_eval_query_idxs]

array([[13749, 13751],
       [27867, 27868],
       [ 3325,  3327],
       ...,
       [28777, 28779],
       [ 2507,  2508],
       [27998, 27999]])

In [53]:
# Rebase the TRAIN_EVAL spans: in the recorded output the result is contiguous, starts at
# 0 and keeps each span's width (e.g. [13749, 13751] -> [0, 2]).
train_eval_spans = batch_generator.path_partitions[
    batch_generator.idx_generators[BatchGenerator.TRAIN_EVAL].idxs]
batch_generator._generate_partition(train_eval_spans)

array([[0.000e+00, 2.000e+00],
       [2.000e+00, 3.000e+00],
       [3.000e+00, 5.000e+00],
       ...,
       [3.108e+03, 3.110e+03],
       [3.110e+03, 3.111e+03],
       [3.111e+03, 3.112e+03]])

In [54]:
# Pull TRAIN batches until two full epochs have been completed. With debug=True every
# call prints one line; in the recorded output that line is the start/end offset within
# the epoch followed by the query indices of the batch. The last batch of an epoch is
# short (12990 12992, two queries) and the second epoch restarts at offset 0 with a
# different index order.
# NOTE(review): this emits thousands of lines of output — consider clearing the output
# before sharing the notebook.
while batch_generator.idx_generators[BatchGenerator.TRAIN].epochs_completed < 2:
    batch_generator.get_batch(BatchGenerator.TRAIN, debug=True)

0 5 [10717 11738  1051  2938 10970]
5 10 [5287 3903 6935 7894 9311]
10 15 [ 7822  7604  9430 12595   456]
15 20 [10405  3987  2231 11603  9593]
20 25 [ 295 3700 5920  458 4685]
25 30 [ 1780  5112  9539  3044 10272]
30 35 [12991  1068  1350 12177 12923]
35 40 [3823 4143 2373 5926 6665]
40 45 [2460 7947 9714 7787 8115]
45 50 [4663 9238 2012 8739 5365]
50 55 [10527 12784  7064  8699  1854]
55 60 [ 1738 11040   623  9928  7055]
60 65 [11796  3541  3164  7269   659]
65 70 [ 6013 11080  3934   459 12849]
70 75 [ 8584  7303 10400  7924 11498]
75 80 [ 6906 10198  1533  1245  8152]
80 85 [ 1449  8214  8066  4379 11316]
85 90 [ 3726   336 11587  9450 10620]
90 95 [11578 11267 12004  5632  6746]
95 100 [  383  6738  4424 10016  2541]
100 105 [ 8200  4509  1939  5304 12286]
105 110 [12341  9462  1617   381 10381]
110 115 [ 4256   618 11905 11780  9546]
115 120 [1018 4753 6938 5676 7389]
120 125 [ 1506  5081  8632  1589 10787]
125 130 [10183 12336 10286  5106  7693]
130 135 [5746 2616 6392 7274 838

1645 1650 [10646 10605  2415   390  1301]
1650 1655 [2285  661 2423 4903 4849]
1655 1660 [10508  6300 11496  4039  8767]
1660 1665 [7176 7995 6941 9264 3930]
1665 1670 [ 9045   583  3854 12982 12290]
1670 1675 [3843 8157 3950  288 9651]
1675 1680 [4943 4054 6109 7681 1832]
1680 1685 [10696 10359 12241  7031 10202]
1685 1690 [1604 9704 8151 9271 1660]
1690 1695 [ 8907  3497  4142 11068 10321]
1695 1700 [5998 3531 7724  616 1056]
1700 1705 [10141  4239  5067  4747   642]
1705 1710 [  259  8190 10434  1978 10763]
1710 1715 [1539  477 2852 6108 3326]
1715 1720 [5404  949 1632 6017 9408]
1720 1725 [ 9100 12431 12572 12234  3604]
1725 1730 [4197 3438 7925 9169 9437]
1730 1735 [ 6273  8386 12805 10043 11711]
1735 1740 [ 9296  1248  6950  2027 10790]
1740 1745 [6316 4349 7581 8452 1950]
1745 1750 [  999 11427 11575  9473  6617]
1750 1755 [5927 2595 2148 9284 2691]
1755 1760 [ 6603  1936  5570  5903 12269]
1760 1765 [ 7247   706   676  3399 11171]
1765 1770 [2822 9979 7769 7738 9971]
1770 1775 

3170 3175 [   62  3041   291 11696  6360]
3175 3180 [ 8053 11210  2506  2235 10560]
3180 3185 [ 2484  1624  2456  2398 10553]
3185 3190 [ 6329  9838 10103 10095  4050]
3190 3195 [12195 10815  5558 10778  8410]
3195 3200 [ 4755 11396   582 12180  1877]
3200 3205 [12655  2179  4418  8346  6803]
3205 3210 [11522 11389   118   470  4728]
3210 3215 [10729  6161  3208  1375  6589]
3215 3220 [ 3241  2969  7729  6491 12549]
3220 3225 [9044 8037 2188 6915 3019]
3225 3230 [5875 6237 3161 5434 5126]
3230 3235 [3849 6381 6099 4046 6544]
3235 3240 [2581 4290 8055 6143  229]
3240 3245 [11312  4355  4136 10929   852]
3245 3250 [ 2820  3523  8142  9038 10989]
3250 3255 [ 7193  3917    17 10911 11582]
3255 3260 [10045 10816  1530  8273 10498]
3260 3265 [12243  1515  9636   927  7477]
3265 3270 [ 9201  2542  6375 11887 10500]
3270 3275 [ 1476  2324 12302  2893  3430]
3275 3280 [9062 7333 6150 9489 8782]
3280 3285 [ 9685  5521 10185   366  6083]
3285 3290 [11652 10703  7848  1670  1803]
3290 3295 [ 6584 

4765 4770 [ 4643 11716 10360 10129  6986]
4770 4775 [10870  4828  9992  1937  1343]
4775 4780 [11466 10625  9706   854 11302]
4780 4785 [ 5674 11679  7323 11555    75]
4785 4790 [9416 9991  921 8774 1167]
4790 4795 [2402 3701 9757 3069 9937]
4795 4800 [ 8104 10857 10499  9409  3130]
4800 4805 [ 5869 10921 12866 12013  4232]
4805 4810 [2421 1799 8868 2805 3343]
4810 4815 [ 4756  1562 12024  3559   192]
4815 4820 [ 8373   576 10990  2569  8773]
4820 4825 [ 5259 11148 11092  3246  6866]
4825 4830 [12316 12451  4098 10803  7157]
4830 4835 [6614 5888 9138 8163 7094]
4835 4840 [10178  1793 10479 11936  4924]
4840 4845 [ 4936  4437  8415  3398 12522]
4845 4850 [ 9282 10003  1777  5277  7757]
4850 4855 [9788  962  617  799 9926]
4855 4860 [ 7791 12426  4568 12119  4035]
4860 4865 [10006  4285  3829  8374  4070]
4865 4870 [ 4560 11418  4782  3901  6817]
4870 4875 [3914 1411  835 1237 2000]
4875 4880 [10485  6219  3361 11261   559]
4880 4885 [10291  5854  6947 12470  4031]
4885 4890 [ 8092 11858

6355 6360 [8741 8279 6024 5010 2105]
6360 6365 [ 6033 11558  4319  9183  6857]
6365 6370 [11309 12708  7405 11133 10504]
6370 6375 [  853 11818  4572   742  1445]
6375 6380 [ 9944  7531 10974  4404  1381]
6380 6385 [ 1677  7783  6647  9335 11352]
6385 6390 [12203  9159  8891 11978 12606]
6390 6395 [ 5117  8762 12357  2490 11583]
6395 6400 [11864   679  1454  6715 12747]
6400 6405 [12377  9716  7956 11873 12339]
6405 6410 [ 9411   690  7867 12423  9280]
6410 6415 [7611  496 4636 2359 7245]
6415 6420 [ 4961 11661  6287  7315  5574]
6420 6425 [  443  5765 12689    13  5041]
6425 6430 [11242 12330  1405  1652 10873]
6430 6435 [ 2837  5712 12666  3352 11149]
6435 6440 [12484  1924 11058  8277  4648]
6440 6445 [ 9868 12172  1217 11798  6190]
6445 6450 [ 3722 10676  8197  2471  2643]
6450 6455 [12239 12329  9210 10882 11612]
6455 6460 [ 8993   754  8369  6816 12986]
6460 6465 [ 4655 11898  2887   189  2943]
6465 6470 [ 2497  1016 10621  4118  3600]
6470 6475 [ 4206 10246  3696  7668 12126]
64

7900 7905 [ 1534  2523  2829  8593 12550]
7905 7910 [ 5678  2405 11278  2493 10454]
7910 7915 [ 7327  9226  1830  3569 10442]
7915 7920 [10308 10948 12908  6758   557]
7920 7925 [  612 10993  6433  2486  7832]
7925 7930 [ 2695 10786  6295 11958  4224]
7930 7935 [ 4583  6121 10205 11034  2103]
7935 7940 [10785 12273  3646  1262  4895]
7940 7945 [9950 3414  646 6464 8099]
7945 7950 [7325 2403  808 1308 3330]
7950 7955 [4972 2255  304 8256 4595]
7955 7960 [ 4730  3737 11601  6407  3994]
7960 7965 [ 8978  4630  1345 10690  6270]
7965 7970 [ 9250   398  3601  6272 12617]
7970 7975 [3254  774 2385 8444 5836]
7975 7980 [10582  7507  1028  2828  8096]
7980 7985 [10096  7630  8746  8213  2487]
7985 7990 [ 4619  8643   447 10194  7679]
7990 7995 [ 2946  6725  7087 10287 11243]
7995 8000 [2875 3028 5124 4095 3049]
8000 8005 [ 4564 10397   327 11529   123]
8005 8010 [8031  340 4103 4588 4045]
8010 8015 [ 5658 11145  6882  7417  7803]
8015 8020 [ 3378  1409  7948  9563 12015]
8020 8025 [1853 6797 5

9505 9510 [ 6214   978  4158 10652 10408]
9510 9515 [ 9294  6827 12594   440  7690]
9515 9520 [ 359 9330 9933 4610 5076]
9520 9525 [7567 7716  893 6652 1141]
9525 9530 [  103  2452 11590  6644 11169]
9530 9535 [ 1618 11423   824  5597  5313]
9535 9540 [10779 12287  3220  1132  3285]
9540 9545 [9615  867 6202 6331 2372]
9545 9550 [ 8625  3751  8080 12317   446]
9550 9555 [  490 12867  3351 11434  3408]
9555 9560 [ 4393  4179   654 11684 10963]
9560 9565 [  246  8733 12802  9547  9206]
9565 9570 [ 9074 10941  1890 12063  5499]
9570 9575 [ 4699  7966  6611 11870 12128]
9575 9580 [12240  7387 10781  7313  7949]
9580 9585 [ 1516  4718 12796  9193  6981]
9585 9590 [10122  2264 11719  3653  2062]
9590 9595 [12140  3355  4017   735  7853]
9595 9600 [ 8789  3096 11248  1463  4120]
9600 9605 [4444  980 5205 5187 8617]
9605 9610 [ 8223  5677  5604 12806  1174]
9610 9615 [ 3784 10496  3947  7491  9161]
9615 9620 [ 3865 11317 11537  9202 10369]
9620 9625 [ 8601 12397  8788 10660  6718]
9625 9630 [8

11120 11125 [ 7677  9069   892  7463 12640]
11125 11130 [ 8325 12707  2719 11843   242]
11130 11135 [ 8777  1794  4548 10004   509]
11135 11140 [4469  822 3793 9856 5212]
11140 11145 [ 7214  1407 10971   536  7861]
11145 11150 [ 2197 10145  1519  3890  7972]
11150 11155 [ 2866 12854  3295  4296 12631]
11155 11160 [12438  2004  2226  1931  9618]
11160 11165 [11103  1093  2170  6074  7721]
11165 11170 [11236 12673  7614  7865  8150]
11170 11175 [10615 11141  8810  1739 11206]
11175 11180 [  68 1784 7096 7135 9797]
11180 11185 [ 3643 11778 10051  3474  9529]
11185 11190 [12399 12236  8982  9200  7588]
11190 11195 [5482 1303 1549 2017 9599]
11195 11200 [12056 11003  2266 10954  7944]
11200 11205 [  872  6656 11708 11281  3026]
11205 11210 [5432 2713 7467 5190 9782]
11210 11215 [2344 5231 9724 1330 1764]
11215 11220 [ 3636   251 11016 11437  2204]
11220 11225 [12046  6369  2097  7155  5213]
11225 11230 [10347  9072  1880 11321 11835]
11230 11235 [  405  3931  2356 11748  5380]
11235 11240 [

12920 12925 [ 4830  1423  5848  8020 12943]
12925 12930 [4711 9008  974 2561 9424]
12930 12935 [12555 11481  5139  6218 10393]
12935 12940 [ 6748  7622  4798 11507  9012]
12940 12945 [ 8749  3177  2774  2658 10020]
12945 12950 [  265 10895   667  6465   124]
12950 12955 [12280   710  3125   995 11091]
12955 12960 [12624   434  4342  1851  4729]
12960 12965 [ 3805  4537  1401 11793 12313]
12965 12970 [ 9143  3289 11847  1976 12368]
12970 12975 [ 1679  6815  1204  9504 10305]
12975 12980 [12113  2661  8483  5473  4353]
12980 12985 [ 8799  3446  9813  1195 10441]
12985 12990 [ 9617  7250  5168  3547 11074]
12990 12992 [2916 2758]
0 5 [ 6034  1715  9735   227 11873]
5 10 [10062   822  5624  6191  9840]
10 15 [ 2984  1564  6959   109 12679]
15 20 [ 7612 11487  6482    52  9882]
20 25 [12898  7156 10634 12990  6373]
25 30 [9728 4665 7319 5632 8545]
30 35 [11883  2014  4741 11930  5857]
35 40 [2610 9291  328 5977 2691]
40 45 [12853  5089  8343  5574   761]
45 50 [10610 12682  9219  9643  7070

1570 1575 [  851 12088  2525  3839  6564]
1575 1580 [ 7268 12600  4568   445    62]
1580 1585 [  160  5980 11176 11326  3988]
1585 1590 [5316 6297 7082 7328 3313]
1590 1595 [ 9139 11396 11734  2938 10601]
1595 1600 [   71  7675 12675 12234 11270]
1600 1605 [7859 8052 8510 2434 5476]
1605 1610 [12541  4547  7521 12161  8303]
1610 1615 [ 4461  6616   775 11866 10302]
1615 1620 [ 7164  1083 11291 10945 10599]
1620 1625 [ 7945 11676   487  4152  4784]
1625 1630 [ 2616 10801  2737  8409  5734]
1630 1635 [ 5885  2547 12211  3842  8261]
1635 1640 [1770 1359 4454 3425 8393]
1640 1645 [ 5434  7187   615 10250  2287]
1645 1650 [ 1762  7278 10639  7088  7561]
1650 1655 [5244 4483 1136 7211 4412]
1655 1660 [4491 8564 5391  575 7625]
1660 1665 [ 4642  3430  3217 11614  6150]
1665 1670 [ 6613  1534  1256 10078  7417]
1670 1675 [ 3073 12946  6076  3776  8011]
1675 1680 [10114  2909  6883  9070  3530]
1680 1685 [11392  3226  1437 10464  7185]
1685 1690 [7479 6570 6528 4971 5532]
1690 1695 [ 7269 10616

3245 3250 [ 5344  4111  8929 11885  5593]
3250 3255 [11423  8340  7243  3205  9286]
3255 3260 [2274 1570 8356 4478 6948]
3260 3265 [4791 4588  770    6 8108]
3265 3270 [ 5711  3037 12628  3745  8930]
3270 3275 [ 6128    44 10363  1027  5403]
3275 3280 [6161 1252 6001 6312 9760]
3280 3285 [9581 2217 7333  329 7672]
3285 3290 [ 6000 11499  3320  2686  3607]
3290 3295 [ 5274 12107  6326  4939   485]
3295 3300 [ 5156  6180 12037  3447  5382]
3300 3305 [11671 12829  1632  4041  1883]
3305 3310 [7689  617  562 6381 3366]
3310 3315 [ 3891  6680  7487  6015 12668]
3315 3320 [ 6977 12548  1963 12735  4250]
3320 3325 [2444  870 2842 8992 9501]
3325 3330 [  900  6643  4283 12473 12470]
3330 3335 [ 5124  7743  3814  8840 11647]
3335 3340 [ 8797  3700  2516  9400 12133]
3340 3345 [ 2061 10894  4290  4347  8294]
3345 3350 [11427  5071  2565  4004  7935]
3350 3355 [ 7279 11002  5052  5522  7069]
3355 3360 [10612 11300 10531  8957 10440]
3360 3365 [ 7084 10015  5435 12836  3660]
3365 3370 [ 7146  7610

4840 4845 [ 9288   983  5952  3109 12206]
4845 4850 [  210  6495  2160  5569 12111]
4850 4855 [10432  2742  6755  3531  4369]
4855 4860 [  751  5908  1434   585 12713]
4860 4865 [9047 2575 9100 3062 8800]
4865 4870 [ 6702  5625 10445  8056   128]
4870 4875 [9852 1659 8915 8316 3352]
4875 4880 [11379   522  8995 11129  7067]
4880 4885 [8165 6607 4191 6351 9105]
4885 4890 [10199  2224 11768 10212  7832]
4890 4895 [11116  9685 12247  5643  8348]
4895 4900 [9815 9090 1409 7125 3940]
4900 4905 [11522  9717 12084 10989  9872]
4905 4910 [ 7797 12013 11334  8370 10117]
4910 4915 [ 3981  3828  5620  6717 12056]
4915 4920 [12573  8698  3747  3980  4838]
4920 4925 [ 1870 12143 10662   723  6308]
4925 4930 [9615 2885 2354 4923  893]
4930 4935 [ 1473  2136  3959 12243 12529]
4935 4940 [ 4600  2829  3054  5290 10943]
4940 4945 [12178   578 12263  4536 10788]
4945 4950 [ 4698  3580 11540  5815  3264]
4950 4955 [12157 12776 12770 11320 11031]
4955 4960 [ 6683  1336  9871 10240  1985]
4960 4965 [4210 9

6500 6505 [ 3770  4819 11860   400  2735]
6505 6510 [10812  1960  6815  5309  2727]
6510 6515 [ 2438  2394 11735  3212 11390]
6515 6520 [ 3015  3567 12984  4381  1649]
6520 6525 [ 672 2770 4582 4961 6825]
6525 6530 [ 6650  3899  9520 11772   709]
6530 6535 [7855  806  121 8540  372]
6535 6540 [ 947 3207 8439 1990 9779]
6540 6545 [12635 11369  2148  4444  5331]
6545 6550 [ 1390  1935   374 12653  4378]
6550 6555 [  557  8268 12980 12012  7359]
6555 6560 [ 7723 12416   173 12771 10045]
6560 6565 [ 4888  5280  2563 11275  2670]
6565 6570 [ 7220 12315  8300 12468 11007]
6570 6575 [12333 10364  7768  5696  9942]
6575 6580 [11272  6223  2452  5275  9894]
6580 6585 [ 6238  2646  9174  4764 12556]
6585 6590 [ 5858 11348  9687  6207  1688]
6590 6595 [6298 8100 4580  531 1907]
6595 6600 [3689 2227 6615 1593 3803]
6600 6605 [3357 9713 9183 7292 3733]
6605 6610 [11740  1448  9757  7860 12777]
6610 6615 [ 2009  4395   369 10265  3112]
6615 6620 [ 6288  5878 12030 12630 11549]
6620 6625 [  749  3329

8045 8050 [12694 11466  9559  5861 10600]
8050 8055 [ 5954  9042  3433  4521 11681]
8055 8060 [ 1240  6661  3438  7064 11659]
8060 8065 [9094 3665 4792 7653 8287]
8065 8070 [11425  1708  8876  6949  8213]
8070 8075 [ 7001  7910  7330   664 12758]
8075 8080 [ 1643  6990 12763  3077  9771]
8080 8085 [12751   456  2004  1847  5718]
8085 8090 [ 9339 11716  8513 10414 12096]
8090 8095 [12014  1014  5278  3230  9026]
8095 8100 [10803  9396 10023  3965  6485]
8100 8105 [  543 10109   482  4707  3913]
8105 8110 [ 1756 11730  6244  2184 12782]
8110 8115 [ 6067  6304  1427 12377  8732]
8115 8120 [ 8814  9924  9058  3096 11584]
8120 8125 [8136 4376 6111 7349 8834]
8125 8130 [ 8718  7203  6640 10721  6251]
8130 8135 [ 9654  7504 10099  4584  6096]
8135 8140 [ 5752 11505  4843  1618  8678]
8140 8145 [ 6437 10942  1849  5137 10278]
8145 8150 [ 8464  8126 12636  2389 10434]
8150 8155 [ 272 8187 9098 8430 2799]
8155 8160 [ 6182  1266  7541  4405 10532]
8160 8165 [12492  3550  2240 10037  5085]
8165 81

9570 9575 [10825  6810  8333  2252   635]
9575 9580 [12944  4323  3429  3046  7902]
9580 9585 [   76 10833  5120  6241    42]
9585 9590 [ 504 3032 9103 4488 9821]
9590 9595 [10211  7255  1611  1148  5955]
9595 9600 [7695 9202 9896 3878 9144]
9600 9605 [7302 3337 8425 1053 2393]
9605 9610 [ 8746 10844  5779 12657  6062]
9610 9615 [  917   353 10549  2497  4175]
9615 9620 [12974  5575  7817   346   955]
9620 9625 [4740 6459 7548 2750 9062]
9625 9630 [10892 10314  3168   871   771]
9630 9635 [ 4636  8297  6571 12544  9444]
9635 9640 [10083  1848  8242  6770  8042]
9640 9645 [ 638 7736 8346 8886 4138]
9645 9650 [ 8188 10849  5339  8649  9839]
9650 9655 [ 4717  3640  6264 10384  7579]
9655 9660 [ 6864 11765  7976  5927 11685]
9660 9665 [7072 5788 6008 1769 9649]
9665 9670 [12400  2355  3763 12918  9059]
9670 9675 [ 3305 12388  9695  3543  6110]
9675 9680 [ 3133  2333 10087  3398   921]
9680 9685 [12861  1057  8955  6050  1410]
9685 9690 [12913  3969  9427  7120 12065]
9690 9695 [10828  6841

11085 11090 [ 7259 12975  8979  2386  4944]
11090 11095 [12603 10635  1892  3882 11809]
11095 11100 [5783 5258 8127 6912 2494]
11100 11105 [  207  8810  5000 11501  9507]
11105 11110 [6965 7761 4703 4123 1183]
11110 11115 [ 9729  2766  1247 12710  8229]
11115 11120 [11333  1984 10198  4504  8329]
11120 11125 [ 3933  2926 10039 12967  9387]
11125 11130 [11194 10496  4505  3642  2399]
11130 11135 [ 8058 12742  5151   438 10293]
11135 11140 [ 8132  5656 11599  7656  8007]
11140 11145 [ 1918 10548 10787   621  2551]
11145 11150 [ 5001 12304  4826  1352  9335]
11150 11155 [10279   978   652  9660  9530]
11155 11160 [ 7595 11285 12830  9795 12754]
11160 11165 [7169 6290 8761 8596 5800]
11165 11170 [  759  3716   472 12171  6127]
11170 11175 [12732 11924  4178  7769 11411]
11175 11180 [11083  7046  9656  2948  9891]
11180 11185 [ 2064  3709  2374 10481  7942]
11185 11190 [1191 5352   11 4030 2142]
11190 11195 [ 2764  7588  7537  1914 10223]
11195 11200 [ 4553  1389 11485  5547  6951]
11200 11

12565 12570 [4900  390 2150 5090  431]
12570 12575 [12900 10072  6737  9593  7184]
12575 12580 [10922  2253  2860  5193  6450]
12580 12585 [9537  101 7715 7908 5682]
12585 12590 [ 3166 10161  7830  6626  9276]
12590 12595 [   14 12459  4363 11278  7835]
12595 12600 [ 5221  8931 12826  7807 10340]
12600 12605 [ 5838 12381  4873  1322  4169]
12605 12610 [ 8002 11957  6490  4026  4759]
12610 12615 [4143 7339 6221 8272  970]
12615 12620 [11791 12707  2181 12912  3422]
12620 12625 [2749 4338 3596 8137 4864]
12625 12630 [6820  294 7522 4078 4453]
12630 12635 [11280  7961 12618 11615  6732]
12635 12640 [12875 11704  4074 10842 12792]
12640 12645 [ 3739 11276    27  6662  4239]
12645 12650 [ 1332  1606 10495  4029  2661]
12650 12655 [8418 3255 3708 3159 3058]
12655 12660 [ 9877 11976  9782 12277  1703]
12660 12665 [11558 12266 11090  8880 10334]
12665 12670 [11317  9279  9542  9968 12409]
12670 12675 [12462  4251 10457  3823  7712]
12675 12680 [ 123 6315 4189 6401 9993]
12680 12685 [ 7403  700

In [55]:
# Same check for the TRAIN_EVAL split. The recorded output shows a single batch per epoch
# covering all 1299 queries, with the identical index order in both epochs.
while batch_generator.idx_generators[BatchGenerator.TRAIN_EVAL].epochs_completed < 2:
    batch_generator.get_batch(BatchGenerator.TRAIN_EVAL, debug=True)

0 1299 [ 5313 10613  1359 ... 11203  1002 10704]
0 1299 [ 5313 10613  1359 ... 11203  1002 10704]


In [56]:
# Hand-picked, non-monotonic query selection (note 10 before 6) to check the rebasing:
# in the recorded output the spans come back contiguous from 0 in selection order, so
# query 6 (original span [19, 20]) ends up last as [20, 21].
selected_queries = np.array([0, 1, 2, 3, 4, 5, 10, 6])
selected_spans = batch_generator.path_partitions[selected_queries]
batch_generator._generate_partition(selected_spans)

array([[ 0.,  1.],
       [ 1.,  4.],
       [ 4.,  5.],
       [ 5., 11.],
       [11., 17.],
       [17., 19.],
       [19., 20.],
       [20., 21.]])

In [39]:
# All rows labelled 1 followed by the first three rows labelled 0.
is_positive = train_dataset['label']==1
is_negative = train_dataset['label']==0
positive_rows = train_dataset.loc[is_positive]
first_negative_rows = train_dataset.loc[is_negative][:3]
pd.concat([positive_rows, first_negative_rows])

Unnamed: 0,id,source,target,relation,entity_paths,relation_paths,label
0,MH_train_0,DB00072,DB00773,interacts_with,"[[DB00072, P16104, DB00773]]","[[2822, 13427, -1]]",1
9,MH_train_1,DB06822,DB09079,interacts_with,"[[DB06822, P48061, P15692, DB09079], [DB06822, P61073, P15692, DB09079]]","[[8149, 8092, 4068, -1], [8149, 8092, 4068, -1]]",1
18,MH_train_2,DB00341,DB00083,interacts_with,"[[DB00341, P35367, P01375, P60880, DB00083], [DB00341, P35367, P05231, P60880, DB00083]]","[[9869, 7118, 9141, 12407, -1], [9869, 7118, 3127, 12407, -1]]",1
27,MH_train_3,DB01171,DB00083,interacts_with,"[[DB01171, P21397, P60880, DB00083]]","[[14165, 1645, 393, -1]]",1
36,MH_train_4,DB01050,DB06813,interacts_with,"[[DB01050, P35354, P00533, DB06813], [DB01050, P35354, P00374, DB06813]]","[[14217, 2353, 12020, -1], [14217, 12117, 12020, -1]]",1
45,MH_train_5,DB01200,DB06288,interacts_with,"[[DB01200, P41595, P34969, DB06288]]","[[741, 11922, 3834, -1]]",1
54,MH_train_6,DB00203,DB00862,interacts_with,"[[DB00203, DB00862]]","[[5158, -1]]",1
63,MH_train_7,DB01656,DB00773,interacts_with,"[[DB01656, P48444, P15692, DB00773], [DB01656, P48444, P16104, DB00773], [DB01656, P48444, Q13315, DB00773], [DB01656, DB02527, Q13315, DB00773]]","[[2400, 7634, 8080, -1], [2400, 8911, 537, -1], [2400, 1968, 537, -1], [2400, 12464, 537, -1]]",1
72,MH_train_8,DB04844,DB01233,interacts_with,"[[DB04844, Q05940, P14416, DB01233], [DB04844, Q01959, P14416, DB01233]]","[[12088, 6616, 3417, -1], [12088, 6616, 3417, -1]]",1
81,MH_train_9,DB01182,DB00277,interacts_with,"[[DB01182, Q12809, P01160, Q92769, DB00277]]","[[13060, 13776, 10681, 12572, -1]]",1
