In [1]:
import multiprocessing as mp
import pickle
import re

import numpy as np
import pandas as pd

import nltk
# Fetch the Punkt sentence-tokenizer data; the call only logs a notice when
# the package is already up to date.
nltk.download('punkt')

import spacy
# The large English pipeline must be installed beforehand with:
#   python -m spacy download en_core_web_lg
nlp = spacy.load("en_core_web_lg")

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
%%time
# Build a text_id -> text lookup from the intro-text CSV, then show its size
# and one sample entry.
text_dict = (
    pd.read_csv('data/intro_text.csv', dtype={'text': str, 'text_id': 'int32'})
    .set_index('text_id')['text']
    .to_dict()
)
display(len(text_dict))
text_dict[12]

5343565

CPU times: user 27.4 s, sys: 2.36 s, total: 29.8 s
Wall time: 39.4 s


"Anarchism is an anti-authoritarian political and social philosophy that rejects hierarchies deemed unjust and advocates their replacement with self-managed, self-governed societies based on voluntary, cooperative institutions. These institutions are often described as stateless societies, although several authors have defined them more specifically as distinct institutions based on non-hierarchical or free associations. Anarchism's central disagreement with other ideologies is that it holds the state to be undesirable, unnecessary, and harmful. Anarchism is usually placed on the far-left of the political spectrum, and much of its economics and legal philosophy reflect anti-authoritarian interpretations of communism, collectivism, syndicalism, mutualism, or participatory economics. As anarchism does not offer a fixed body of doctrine from a single particular worldview, many anarchist types and traditions exist and varieties of anarchy diverge widely. Anarchist schools of thought can di

In [3]:
%%time
# Load the entity table: one row per (entity, page_id, item_id, text_id),
# with the id columns read as int32.
entity_df = pd.read_csv(
    'data/intro_entity.csv',
    dtype={
        'entity': str,
        'page_id': 'int32',
        'item_id': 'int32',
        'text_id': 'int32',
    },
)
entity_df

CPU times: user 20.1 s, sys: 1.2 s, total: 21.3 s
Wall time: 25.7 s


Unnamed: 0,entity,page_id,item_id,text_id
0,anti-authoritarian,867979,1030234,12
1,political,23040,179805,12
2,social philosophy,586276,180592,12
3,hierarchies,13998,188619,12
4,self-managed,40949353,15981562,12
...,...,...,...,...
35840002,Carl Randall,40277554,16215506,62473330
35840003,The World Ends With You,6987282,1416303,62473330
35840004,2016 Summer Olympics closing ceremony,44593137,18741083,62473330
35840005,2020 Summer Olympics,1610886,181278,62473330


In [4]:
%%time
# Build the en_label -> item_ids lookup from the Feather file
# (an earlier revision unpickled data/item_dict.p instead).
item_dict = (
    pd.read_feather('data/item_dict.ftr', use_threads=True)
    .set_index('en_label')['item_ids']
    .to_dict()
)
display(len(item_dict))
item_dict['tesla']

48191954

CPU times: user 1min 4s, sys: 8.52 s, total: 1min 13s
Wall time: 1min 21s


array([    9036,   163343,   210893,   478214,   622424,   765530,
         780348,  1050485,  1428953,  1463050,  1548225,  1634161,
        2384079,  2406220,  3982823,  5172712,  7035686,  7705502,
        7705506,  7705515, 16258100, 19565583, 19845823, 23663332,
       27701406, 31803712, 37251206, 56084926])

In [5]:
def get_id(entity):
    """Look up the candidate item ids stored in ``item_dict`` for a label.

    The label is tried verbatim first. On a miss it is relaxed one step at a
    time and the lookup is retried after every step:

    1. surrounding single quotes, double quotes and spaces are stripped;
    2. a leading ``'the '`` is dropped.

    Args:
        entity: Label to look up. No case normalisation happens here
            (``get_ids`` casefolds labels before calling this function).

    Returns:
        The value stored in ``item_dict`` for the first variant that is
        present, or an empty list when none of the variants is a key.
    """
    while True:
        try:
            return item_dict[entity]
        except KeyError:
            # Only a genuine miss falls through to the relaxations below;
            # any other error (or an interrupt) propagates to the caller.
            pass

        trimmed = entity.strip('\'" ')
        if len(trimmed) < len(entity):
            entity = trimmed
        elif entity.startswith('the '):
            entity = entity[4:]
        else:
            return []
    
def get_ids(entities):
    """Resolve every label in ``entities`` through ``get_id``.

    Each label is stripped of surrounding whitespace and casefolded before
    the lookup. The result is a list with one entry per input label, in
    input order.
    """
    resolved = []
    for label in entities:
        normalised = label.strip().casefold()
        resolved.append(get_id(normalised))
    return resolved

# Smoke test: a direct hit, a hit reached via the quote/'the ' fallbacks,
# and two labels that resolve to nothing.
print(get_ids(['Tesla', '"The Tesla', 'Teslarati', 'The Teslarati']))

[array([    9036,   163343,   210893,   478214,   622424,   765530,
         780348,  1050485,  1428953,  1463050,  1548225,  1634161,
        2384079,  2406220,  3982823,  5172712,  7035686,  7705502,
        7705506,  7705515, 16258100, 19565583, 19845823, 23663332,
       27701406, 31803712, 37251206, 56084926]), array([    9036,   163343,   210893,   478214,   622424,   765530,
         780348,  1050485,  1428953,  1463050,  1548225,  1634161,
        2384079,  2406220,  3982823,  5172712,  7035686,  7705502,
        7705506,  7705515, 16258100, 19565583, 19845823, 23663332,
       27701406, 31803712, 37251206, 56084926]), [], []]


In [6]:
# English Punkt sentence tokenizer, loaded from the NLTK data directory.
# NOTE(review): recent NLTK releases ship `punkt_tab` data instead of this
# pickle -- confirm the path against the installed NLTK version.
sent_detector = nltk.data.load('tokenizers/punkt/english.pickle')

In [7]:
def _sentence_bounds(starts, entity_start, entity_end, text_len):
    """Return the (start, end) offsets of the sentence window around a span.

    ``starts`` are sentence start offsets in ascending order. The window
    begins at the last start at or before ``entity_start`` and ends at the
    first start at or after ``entity_end``; it falls back to ``0`` and
    ``text_len`` when no such start exists.
    """
    window_start, window_end = 0, text_len
    for offset in starts:
        if offset <= entity_start:
            window_start = offset
        elif offset >= entity_end:
            window_end = offset
            break
    return window_start, window_end


def get_sentence(row):
    """Collect the labels that co-occur with an entity mention in its sentence.

    The text for ``text_id`` is read from ``text_dict`` and the first
    occurrence of ``entity`` is located in it. The surrounding sentence
    (per ``sent_detector``) is passed to ``nlp`` with the mention itself cut
    out, and the recognised entities are gathered into a label list.

    Args:
        row: Sequence of ``(entity, page_id, item_id, text_id)``.

    Returns:
        A 3-tuple ``(entity_id_list, entity_list, (entity, page_id, item_id))``:
        ``entity_list`` is an object ndarray of unique, normalised labels with
        the target entity first, and ``entity_id_list`` is ``get_ids`` applied
        to it (one entry per label).

    Raises:
        ValueError: If ``entity`` does not occur in the text.
        KeyError: If ``text_id`` is not a key of ``text_dict``.
    """
    entity, page_id, item_id, text_id = row
    text = text_dict[text_id]

    # Only the first occurrence of the mention is considered.
    entity_start = text.find(entity)
    if entity_start == -1:
        raise ValueError(f'[{entity}] was not found in text:\n {text}')
    entity_end = entity_start + len(entity)

    sentence_starts = [start for start, _ in sent_detector.span_tokenize(text)]
    sentence_start, sentence_end = _sentence_bounds(
        sentence_starts, entity_start, entity_end, len(text)
    )

    # Run NER on the sentence with the target mention removed, so the mention
    # is not reported a second time as a context entity.
    doc = nlp(text[sentence_start:entity_start] + text[entity_end:sentence_end])

    entity_list = [entity.strip().casefold()]
    for ent in doc.ents:
        # For person names of three or more tokens that are not known labels,
        # also try the "first last" form (middle tokens dropped).
        if ent.label_ == 'PERSON' and len(ent.text.split()) > 2:
            lowered = ent.text.lower()
            if lowered not in item_dict:
                parts = lowered.split()
                entity_list.append(parts[0] + ' ' + parts[-1])

    # Every recognised entity except dates is kept as a context label.
    entity_list += [
        ent.text.strip().casefold() for ent in doc.ents if ent.label_ != 'DATE'
    ]
    # Order-preserving de-duplication. The list is wrapped in an object Series
    # because newer pandas releases deprecate passing a plain list to
    # pd.unique; the result is still a NumPy object array.
    entity_list = pd.unique(pd.Series(entity_list, dtype=object))
    entity_id_list = get_ids(entity_list)
    return entity_id_list, entity_list, (entity, page_id, item_id)
    
# Spot-check the extractor on the first five rows of the entity table.
[get_sentence(row) for row in entity_df.head(5).to_numpy()]

[([array([   6229, 1030234])],
  array(['anti-authoritarian'], dtype=object),
  ('anti-authoritarian', 867979, 1030234)),
 ([array([    6216,     7163,     7169,     7257,     7278,     8683,
             11268,    25107,    27778,    28108,    30849,    36442,
             42388,    56061,    80330,    82955,   122131,   131160,
            133136,   148837,   159385,   159493,   166542,   168559,
            179805,   188958,   191320,   191600,   203764,   204886,
            214183,   217105,   265147,   276861,   277569,   303132,
            330963,   333024,   392160,   442725,   466439,   496276,
            502144,   517372,   535347,   564824,   568452,   599048,
            622291,   635322,   669701,   678363,   699375,   706243,
            707438,   745692,   829147,   831058,   836655,   847301,
            855726,   865455,   871419,   875894,   884987,   900406,
            904852,   917351,   934744,   950320,   963402,   969663,
            986069,  1018769,  1048091

In [8]:
# CPU count reported by multiprocessing; used below to size the worker pool.
# (`mp` is imported with the other modules at the top of the notebook.)
n_cores = mp.cpu_count()
n_cores

16

In [9]:
%%time
# Draw a fixed-seed random subset of entity rows and convert it to a plain
# array of (entity, page_id, item_id, text_id) records.
samples = (
    entity_df
    .sample(n=250000, random_state=1)
    .to_numpy()
)
samples

CPU times: user 2.02 s, sys: 95.8 ms, total: 2.11 s
Wall time: 2.11 s


array([['Military Cross', 488249, 1335064, 187626],
       ['Wales', 69894, 25, 42507907],
       ['Joe Conforte', 11674164, 6209172, 1051658],
       ...,
       ['Manchester City F.C', 165813, 50602, 761536],
       ['Christian', 5211, 5043, 7857964],
       ['Chicago, Illinois', 6886, 1297, 6199675]], dtype=object)

In [10]:
%%time
# ~10 seconds per 1000 samples with 16 cores
# Use all but two cores, but never fewer than one worker (n_cores - 2 would be
# invalid on machines with one or two cores). The context manager shuts the
# worker processes down once map() has returned every result, so they do not
# linger for the rest of the kernel session.
with mp.Pool(max(1, n_cores - 2)) as p:
    sentences = p.map(get_sentence, samples)
display(len(sentences))
# Preview a few results only; echoing the whole list bloats the notebook.
sentences[:3]

250000

CPU times: user 11.8 s, sys: 3.66 s, total: 15.4 s
Wall time: 28min 36s


[([array([  386083,   869896,  1335064,  1788804,  2727598,  6852012,
          12090795])],
  array(['military cross'], dtype=object),
  ('Military Cross', 488249, 1335064)),
 ([array([      25,      145,      181,     3224,     9309,    10690,
             18996,    26010,    80043,   180729,   180857,   185692,
            188353,   209352,   218555,   229654,   275173,   281113,
            302442,   313262,   374194,   482546,   493517,   545607,
            599661,   658888,   666063,   683738,   731669,   749686,
            807103,   816814,   822877,   832685,   834536,   867913,
            946297,   975380,   980162,   988326,  1028218,  1041261,
           1063608,  1066586,  1070187,  1072029,  1077282,  1089788,
           1132110,  1135773,  1138780,  1274559,  1286223,  1361101,
           1478550,  1483510,  1521253,  1605149,  1646792,  1752400,
           1846384,  1892770,  2000845,  2038022,  2075835,  2075956,
           2113023,  2243531,  2249932,  2316274,  231

In [11]:
%%time
# Persist the extracted samples. The context manager guarantees the file is
# flushed and closed even if pickling fails part-way through.
with open('data/dataset.p', 'wb') as dataset_file:
    pickle.dump(sentences, dataset_file, protocol=pickle.HIGHEST_PROTOCOL)

CPU times: user 8.18 s, sys: 817 ms, total: 9 s
Wall time: 8.85 s
