In [None]:
import pandas as pd
from tqdm import tqdm
import spacy
import neuralcoref
import re
from newspaper import Article
from googlesearch import search
import nltk
import json
import itertools
import tldextract
import itertools
from zeroshot_topics import ZeroShotTopicFinder
from transformers import AutoTokenizer
from zero_shot_re import RelTaggerModel, RelationExtractor
from polyfuzz import PolyFuzz
from transformers import pipeline
from neo4j import GraphDatabase
import hashlib
from collections import Counter

In [None]:
## load the NLP models used throughout the notebook
# spaCy pipeline with the neuralcoref component added; later cells read
# `doc._.coref_resolved` from documents produced by this pipeline.
nlp = spacy.load('en')
neuralcoref.add_to_pipe(nlp)
# second spaCy pipeline without coreference resolution
# NOTE(review): `nlp_two` is not referenced anywhere else in the visible cells.
nlp_two =spacy.load('en_core_web_sm')
# Hugging Face sentiment-analysis pipeline (default model for the task), applied per paragraph
sentiment = pipeline(task = 'sentiment-analysis')
# zero-shot topic finder, applied per paragraph via `find_topic`
zsmodel = ZeroShotTopicFinder()
# PolyFuzz string matcher (TF-IDF) used to group near-duplicate entity names
model = PolyFuzz("TF-IDF")
# relation tagger and tokenizer handed to `RelationExtractor` further down
model_two = RelTaggerModel.from_pretrained("fractalego/fewrel-zero-shot")
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")

In [None]:
## function to download articles 
## these could be amended to capture other data reddit/twitter/other social media
## other text preprocessing can take place here as needed
## tldextract extracts the root of the website so we document the sources

def googlescrape(topic_name, num_results, config=None, min_length=500):
    """Search Google for a topic and return cleaned article texts.

    Parameters
    ----------
    topic_name : str
        Search term passed to ``googlesearch.search``.
    num_results : int
        Number of search results to request.
    config : optional
        Configuration object forwarded to ``newspaper.Article``. When ``None``
        the article is created with the library defaults. (The previous
        version referenced an undefined global ``config``; the resulting
        ``NameError`` was swallowed, so no article was ever returned.)
    min_length : int, default 500
        Minimum number of characters the joined long paragraphs must have for
        the article to be kept. The previous version kept only articles
        *shorter* than 500 characters, which discarded every real article.

    Returns
    -------
    list of list
        One ``[text, text_no_n, text_list, domain]`` entry per kept article:
        the cleaned text, the paragraphs longer than 8 words joined by spaces,
        the cleaned text split on blank lines, and the registered domain of
        the URL.
    """
    results = search(topic_name, num_results)

    sources = []
    for url in results:
        try:
            if config is not None:
                article = Article(url, config=config)
            else:
                article = Article(url)
            article.download()
            article.parse()

            text = article.text
            text = re.sub(r'http\S+', '', text)  # remove urls
            # empty out bracketed text, then drop the leftover empty brackets
            text = re.sub(r"([\(\[]).*?([\)\]])", r"\g<1>\g<2>", text)
            text = text.replace('()', "")
            text = text.replace('[]', "")
            t = text.encode('ascii', 'ignore').decode()  # remove unusual characters
            # removing brackets / non-ascii characters can join fragments into a new url
            t = re.sub(r"https*\S+", " ", t)

            paragraphs = t.split('\n\n')
            # keep only paragraphs with more than 8 words (drops captions, bylines, menus)
            long_paragraphs = [p for p in paragraphs if len(p.split(" ")) > 8]
            joined = " ".join(long_paragraphs)

            if len(joined) >= min_length:
                sources.append([t, joined, paragraphs, tldextract.extract(url).domain])
        except Exception as exc:
            # best effort: one unreachable or unparsable page must not stop the scrape,
            # but report it instead of hiding the failure
            print(f"Skipping {url}: {exc!r}")

    return sources

In [None]:
## replace the placeholder `yoursearchtermhere` with your search term and adjust the number of articles
data = googlescrape(topic_name=yoursearchtermhere, num_results=15)

In [None]:
## enrich articles with sentiment analysis and zero-shot topics and create a dataframe
# Each article gets the most frequent topic and the most frequent sentiment
# label found across its paragraphs.

dflist = []
skipped_paragraphs = 0
print(len(data))
for d in tqdm(data):
    topiclist = []
    sentiment_list = []
    for x in d[2]:
        try:
            topiclist.extend(zsmodel.find_topic(x, n_topic=2))
            sentiment_list.append(sentiment(x)[0]['label'])
        except Exception:
            # a paragraph the models cannot handle should not discard the whole article
            skipped_paragraphs += 1

    # nothing could be classified for this article: leave it out of the dataframe
    if not topiclist or not sentiment_list:
        continue

    # Counter.most_common is linear and breaks ties by first occurrence, so the
    # result is reproducible (max over a set depends on hash ordering)
    topic = Counter(topiclist).most_common(1)[0][0]
    sentiment_value = Counter(sentiment_list).most_common(1)[0][0]

    dflist.append([d[0], d[1], d[2], d[3], topic, sentiment_value])

if skipped_paragraphs:
    print(f"Paragraphs skipped because a model raised an error: {skipped_paragraphs}")

df = pd.DataFrame(dflist, columns=['text', "text_no_n", 'text_list', 'domain', 'topic', 'sentiment'])

In [None]:
## extracts entity pairs from text

def get_entity_pairs(text, coref=True):
    """Extract subject / relation / object triples from simple sentences.

    The text is parsed with the module-level spaCy pipeline ``nlp``. Only
    sentences with exactly one subject and exactly one direct object are
    considered; the relation is the sentence ROOT, extended by a directly
    following adposition or particle.

    Parameters
    ----------
    text : str
        Raw text to analyse.
    coref : bool, default True
        When true, coreference clusters are resolved (``doc._.coref_resolved``)
        before the sentences are extracted.

    Returns
    -------
    pandas.DataFrame
        Columns ``subject``, ``relation``, ``object``, ``subject_type`` and
        ``object_type``. Rows containing an empty string are dropped. The
        number of extracted pairs is also printed.
    """
    # preprocess text
    text = re.sub(r'\n+', '.', text)  # replace multiple newlines with period
    text = re.sub(r'\[\d+\]', ' ', text)  # remove reference numbers
    text = nlp(text)
    if coref:
        text = nlp(text._.coref_resolved)  # resolve coreference clusters

    def refine_ent(ent, sent):
        """Return ``(entity, entity_type)`` for a merged token of ``sent``."""
        unwanted_tokens = (
            'PRON',  # pronouns
            'PART',  # particle
            'DET',  # determiner
            'SCONJ',  # subordinating conjunction
            'PUNCT',  # punctuation
            'SYM',  # symbol
            'X',  # other
        )
        ent_type = ent.ent_type_  # get entity type
        if ent_type == '':
            # not a named entity: treat it as a noun chunk and strip filler words
            ent_type = 'NOUN_CHUNK'
            ent = ' '.join(
                str(t.text) for t in nlp(str(ent))
                if t.pos_ not in unwanted_tokens and not t.is_stop
            )
        elif ent_type in ('NOMINAL', 'CARDINAL', 'ORDINAL') and ' ' not in str(ent):
            # single-token number-like entity: extend it with the following
            # tokens up to the next verb or punctuation mark
            refined = ''
            for i in range(len(sent) - ent.i):
                if ent.nbor(i).pos_ not in ('VERB', 'PUNCT'):
                    refined += ' ' + str(ent.nbor(i))
                else:
                    ent = refined.strip()
                    break

        return ent, ent_type

    # Span.text is the supported accessor; Span.string is a deprecated alias
    sentences = [sent.text.strip() for sent in text.sents]  # split text into sentences
    ent_pairs = []
    for sent in sentences:
        sent = nlp(sent)
        spans = list(sent.ents) + list(sent.noun_chunks)  # collect nodes
        spans = spacy.util.filter_spans(spans)
        # merge every entity / noun chunk into a single token
        with sent.retokenize() as retokenizer:
            for span in spans:
                retokenizer.merge(span, attrs={'tag': span.root.tag,
                                               'dep': span.root.dep})
        deps = [token.dep_ for token in sent]

        # limit our example to simple sentences with one subject and object
        if (deps.count('obj') + deps.count('dobj')) != 1 \
                or (deps.count('subj') + deps.count('nsubj')) != 1:
            continue

        for token in sent:
            if token.dep_ not in ('obj', 'dobj'):  # identify object nodes
                continue
            subject = [w for w in token.head.lefts
                       if w.dep_ in ('subj', 'nsubj')]  # identify subject nodes
            if not subject:
                continue
            subject = subject[0]

            # identify relationship by root dependency
            relation = [w for w in token.ancestors if w.dep_ == 'ROOT']
            if relation:
                relation = relation[0]
                # add adposition or particle to relationship; guard the lookup
                # because Token.nbor raises IndexError on the last token
                if (relation.i + 1 < len(sent)
                        and relation.nbor(1).pos_ in ('ADP', 'PART')):
                    relation = ' '.join((str(relation), str(relation.nbor(1))))
            else:
                relation = 'unknown'

            subject, subject_type = refine_ent(subject, sent)
            obj, object_type = refine_ent(token, sent)

            ent_pairs.append([str(subject), str(relation), str(obj),
                              str(subject_type), str(object_type)])

    # drop triples in which any field ended up empty after refinement
    ent_pairs = [sublist for sublist in ent_pairs
                 if not any(str(ent) == '' for ent in sublist)]
    pairs = pd.DataFrame(ent_pairs, columns=['subject', 'relation', 'object',
                                             'subject_type', 'object_type'])
    print('Entity pairs extracted:', str(len(ent_pairs)))

    return pairs

In [None]:
## build a column with the named entities found in each article

text_list = df['text_no_n'].tolist()
entity_column = []

for t in text_list:
    # parse once to resolve coreferences, then parse the resolved text for entities
    resolved = nlp(nlp(t)._.coref_resolved)
    entity_column.append([
        ent.text for ent in resolved.ents
        if ent.label_ in ['GPE', 'ORG', 'PERSON', 'NORP', 'MONEY', 'EVENT', 'WORK_OF_ART', 'FAC', 'NOUN_CHUNK']
    ])

# every distinct entity across all articles, shortest names first
all_entities = sorted(set(itertools.chain.from_iterable(entity_column)), key=len)

# per-article entity lists without repeats
entity_column_corrected = [list(set(article_entities)) for article_entities in entity_column]
df['entities'] = entity_column_corrected

In [None]:
## zero_shot extraction sometimes creates multiple names for the same entity despite coreference efforts
# polyfuzz is a model which can help detect similar items and group them together
model.match(all_entities)
matches = model.get_matches()
similar_words_df = matches[matches.Similarity > 0.7]

# number of articles each entity appears in (the per-article lists hold no repeats)
emptycounter = Counter()
for article_entities in df['entities']:
    emptycounter.update(article_entities)

common_ents = dict(emptycounter)

# map an entity name to the similar name that should replace it
# NOTE(review): a "From" name that is MORE common than its match is replaced by
# the less common one - confirm this direction is intended.
similar_dictionary = {}
for source_name, target_name in zip(similar_words_df['From'], similar_words_df['To']):
    source_is_more_common = common_ents[source_name] > common_ents[target_name]
    if source_is_more_common or len(source_name) < len(target_name):
        similar_dictionary[source_name] = target_name
    else:
        similar_dictionary[target_name] = source_name

In [None]:
## remove duplicates from entity column
def deduplicator(ents, mapping=None):
    """Replace entity names by their preferred similar name.

    Parameters
    ----------
    ents : iterable of str
        Entity names of one article.
    mapping : dict, optional
        Maps a name to its replacement. Defaults to the module-level
        ``similar_dictionary``.

    Returns
    -------
    list of str
        The names in their original order, with mapped names replaced.
        Names containing a currency symbol ("£" or "$") are never replaced.
    """
    if mapping is None:
        mapping = similar_dictionary

    # The previous check `("£" or "$") not in dict(short_counter)` only ever tested
    # "£" and referenced an undefined name. Presumably the intent was to keep money
    # amounts apart, since different amounts look alike to a fuzzy matcher.
    currency_symbols = ("£", "$")

    new_ents = []
    for x in ents:
        is_money = any(symbol in x for symbol in currency_symbols)
        if x in mapping and not is_money:
            new_ents.append(mapping[x])
        else:
            new_ents.append(x)
    return new_ents


# apply on the column itself: unlike a row-wise DataFrame.apply this always yields a
# Series (also when df has no rows) and hands each entity list straight to the function
df['entities_corrected'] = df['entities'].apply(deduplicator)

In [None]:
## connect to the neo4j sandbox instance
# The values below are placeholders: fill in the bolt address and the
# credentials of your own sandbox before running.

host = 'bolt://xx.xx.xx.xx:xxxx'
user = 'xxxx'
password = 'xxxx-xxxx-xxxx'
driver = GraphDatabase.driver(host, auth=(user, password))


def neo4j_query(query, params=None):
    """Run a Cypher query in a fresh session and return its records as a DataFrame.

    ``params`` is passed to ``session.run`` as the query parameters. The
    DataFrame has one row per record and the result keys as column names.
    """
    with driver.session() as session:
        result = session.run(query, params)
        rows = [record.values() for record in result]
        return pd.DataFrame(rows, columns=result.keys())

In [None]:
## create a stable hash column identifying each article (key of the Hash nodes in neo4j)
# The built-in hash() of a str is salted per Python process, so it yields different
# values after every kernel restart and MERGE would create duplicate Hash nodes.
# A SHA-256 hex digest is reproducible, and as a string it is the same value
# wherever the hash is later wrapped in str().
df['hash'] = df['text'].apply(lambda text: hashlib.sha256(text.encode('utf-8')).hexdigest())

In [None]:
## upload topics, domains, hashes and their relationships to the neo4j sandbox instance
# One query per article: MERGE keeps nodes and relationships unique across articles.
# NOTE(review): the Centre node is hard-coded to '1MDB' - confirm it matches your search term.

topic_upload_query = """
    MERGE (a:Centre{centre:$centre})
    MERGE (b:Topic{topic:$topic})
    MERGE (c:Domain{domain:$domain})
    MERGE (d:Hash{hash:$hash})
    MERGE (a)-[:TOPIC]->(b)
    MERGE (c)-[:COVERS]->(b)
    MERGE (c)-[:MENTIONS]->(d)
    """

for row_number, row in df.iterrows():
    article_params = {
        'centre': '1MDB',
        'topic': row['topic'],
        'domain': row['domain'],
        'hash': row['hash'],
    }
    neo4j_query(topic_upload_query, article_params)

In [None]:
# link each extracted entity to the article (Hash node) that mentions it

entity_link_query = """
            MERGE (x:Entity{entity:$entity})
            MERGE (d:Hash{hash:$hash})
            MERGE (d)-[:MENTIONS]->(x)
            """

for row_number, row in df.iterrows():
    for entity in row['entities_corrected']:
        if entity == '':
            continue  # nothing to link for an empty name
        neo4j_query(entity_link_query, {'entity': entity, 'hash': row['hash']})

In [None]:
## zero-shot relations to extract
# `relations` lists the candidate relation labels handed to the extractor;
# `extractor.rank(...)` is called later for pairs of entities.

relations = ['linked'] ## can hold several labels of your choosing; some relations work better than others
# built from the relation tagger and tokenizer loaded in the model cell
extractor = RelationExtractor(model_two, tokenizer, relations)


In [None]:
## chunk text to be processed by model_two

def get_chunks(s, maxlength):
    """Yield consecutive pieces of ``s``, splitting on spaces where possible.

    Each piece is cut at the last space found within the next
    ``maxlength + 1`` characters; that space is dropped. When the window holds
    no space at all (one very long token), the piece is cut hard after
    ``maxlength`` characters so that no text is lost or yielded twice.

    Parameters
    ----------
    s : str
        Text to split.
    maxlength : int
        Maximum length of a piece; must be at least 1.

    Yields
    ------
    str
        The pieces in order. The final piece is whatever remains and may be
        the empty string.

    Raises
    ------
    ValueError
        If ``maxlength`` is smaller than 1.
    """
    if maxlength < 1:
        raise ValueError("maxlength must be at least 1")

    start = 0
    while start + maxlength < len(s):
        end = s.rfind(" ", start, start + maxlength + 1)
        if end == -1:
            # no space in the window: hard split, nothing is skipped
            end = start + maxlength
            yield s[start:end]
            start = end
        else:
            yield s[start:end]
            start = end + 1  # skip the space the piece was cut at
    yield s[start:]


In [None]:
## assemble list of assessed predictions from text
# Two sources of relations per text chunk:
#   1. dependency-parsed subject/relation/object triples (get_entity_pairs)
#   2. zero-shot ranking of the candidate relations for every pair of article entities

CHUNK_SIZE = 512       # characters per chunk handed to the models
MIN_CONFIDENCE = 0.8   # minimum zero-shot score for a relation to be kept

predicted_rels = []
texts = df['text_no_n'].tolist()
# iterate over every article (the earlier range(0, len - 1) skipped the last one)
for x in tqdm(range(len(texts))):
    article = df.iloc[x]
    source = str(article['hash'])
    # the entity pairs depend on the article only, so build them once per article
    combinations = list(itertools.combinations(list(set(article['entities_corrected'])), 2))

    for paragraph in get_chunks(texts[x], CHUNK_SIZE):
        try:
            temp_df = get_entity_pairs(paragraph)
            ## relationships with NOUN_CHUNK had high false positive rates
            keep = (temp_df["subject_type"].str.contains("NOUN_CHUNK").eq(False)
                    & temp_df["object_type"].str.contains("NOUN_CHUNK").eq(False))
            for index, row in temp_df[keep].iterrows():
                predicted_rels.append({'head': row['subject'], 'tail': row['object'],
                                       'type': row['relation'], 'source': source})
        except Exception as exc:
            # a parsing failure must not prevent the zero-shot pass on this chunk
            print(f"Entity pair extraction failed for a chunk of article {x}: {exc!r}")

        zero_shot_text = paragraph.replace(",", " ")
        for combination in combinations:
            try:
                ranked_rels = extractor.rank(text=zero_shot_text, head=combination[0], tail=combination[1])
                best_type, best_score = ranked_rels[0][0], ranked_rels[0][1]
            except Exception:
                # best effort: skip entity pairs the extractor cannot rank
                continue
            if best_score > MIN_CONFIDENCE:
                predicted_rels.append({'head': combination[0], 'tail': combination[1],
                                       'type': best_type, 'source': source})

In [None]:
## upload the predicted relationships in a single batched query
# Each relation becomes its own Relation node between head and tail entity,
# and the source article (Hash node) is linked to that Relation node.

relationship_upload_query = """
UNWIND $data as row
MERGE (x:Entity{entity:row.head})
MERGE (y:Entity{entity:row.tail})
MERGE (d:Hash{hash:row.source})
MERGE (x)-[:REL]->(r:Relation {type: row.type})-[:REL]->(y)
MERGE (d)-[:MENTIONS]->(r)
"""

neo4j_query(relationship_upload_query, {'data': predicted_rels})
