In [4]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch

import wikipedia
from newspaper import Article, ArticleException
from GoogleNews import GoogleNews

from pyvis.network import Network
import IPython

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Hugging Face checkpoint name shared by the tokenizer and the seq2seq model.
REBEL_CHECKPOINT = "Babelscape/rebel-large"

tokenizer = AutoTokenizer.from_pretrained(REBEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(REBEL_CHECKPOINT)

In [6]:
# from https://huggingface.co/Babelscape/rebel-large
def extract_relations_from_model_output(text):
    """Parse one decoded model output string into a list of relation dicts.

    The string is scanned token by token (whitespace split) after removing
    the "<s>", "<pad>" and "</s>" markers. Three marker tokens drive the
    parser:

    * "<triplet>" - following tokens are collected as the relation head;
    * "<subj>"    - following tokens are collected as the relation tail;
    * "<obj>"     - following tokens are collected as the relation type.

    Args:
        text: decoded model output that still contains the marker tokens.

    Returns:
        A list of dicts with the keys 'head', 'type' and 'tail', each value
        stripped of surrounding whitespace, in the order they were completed.
    """
    relations = []
    # Text accumulated so far for each part of the relation being read.
    fields = {'head': '', 'type': '', 'tail': ''}
    # Which entry of `fields` ordinary tokens are appended to; None until the
    # first marker is seen, so leading tokens are ignored.
    slot = None

    def flush():
        # Record a stripped copy of the relation currently held in `fields`.
        relations.append({name: value.strip() for name, value in fields.items()})

    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(marker, "")

    for token in cleaned.split():
        if token == "<triplet>":
            slot = 'head'
            # A pending type means the previous relation is complete.
            if fields['type'] != '':
                flush()
                fields['type'] = ''
            fields['head'] = ''
        elif token == "<subj>":
            slot = 'tail'
            # Same head, new tail: emit the previous relation first. The type
            # is deliberately left untouched here; "<obj>" resets it.
            if fields['type'] != '':
                flush()
            fields['tail'] = ''
        elif token == "<obj>":
            slot = 'type'
            fields['type'] = ''
        elif slot is not None:
            fields[slot] += ' ' + token

    # Emit the trailing relation only when all three parts were filled in.
    if all(value != '' for value in fields.values()):
        flush()
    return relations

In [23]:
class KnowledgeBase():
    """In-memory store of entities, the relations between them, and their sources.

    Attributes:
        relations: list of relation dicts with the keys 'head', 'type', 'tail'
            and 'meta'. 'meta' maps an article URL to ``{'span': [...]}``.
        entities: dict mapping an entity title to the remaining entity data
            (the 'url' and 'summary' returned by `get_wikipedia_data`).
        sources: dict mapping an article URL to
            ``{'article_title': ..., 'article_publish_date': ...}``.
    """

    def __init__(self):
        self.relations = []
        self.entities = {}
        self.sources = {}

    def are_relations_equal(self, r1, r2):
        """Return True when `r1` and `r2` have the same head, type and tail."""
        return all(r1[attr] == r2[attr] for attr in ['head', 'type', 'tail'])

    def relation_exists(self, r1):
        """Return True when a relation equal to `r1` is already stored."""
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relations(self, r, article_title, article_publish_date):
        """Resolve a relation's entities and add it to the knowledge base.

        The head and tail are looked up with `get_wikipedia_data`; if either
        lookup fails the relation is silently dropped. Otherwise both
        entities are stored, `r['head']` / `r['tail']` are rewritten IN PLACE
        to the resolved titles, the article is registered as a source, and
        the relation is appended (or merged into an equal existing one).

        Args:
            r: relation dict with 'head', 'type', 'tail' and a non-empty
                'meta' mapping; the first key of 'meta' is used as the URL.
            article_title: title stored for the source article.
            article_publish_date: publish date stored for the source article.
        """
        candidate_entities = [r['head'], r['tail']]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # Keep only relations whose two ends both resolved to an entity.
        if any(ent is None for ent in entities):
            return

        for e in entities:
            self.add_entity(e)

        # Normalise the relation ends to the resolved titles.
        r['head'] = entities[0]['title']
        r['tail'] = entities[1]['title']

        article_url = list(r['meta'].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                'article_title': article_title,
                'article_publish_date': article_publish_date
            }

        if not self.relation_exists(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        """Print every entity, relation and source, one per line."""
        print("Entities:")
        for e in self.entities.items():
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
        print("Sources:")
        for s in self.sources.items():
            print(f"  {s}")

    def merge_with_kb(self, kb2):
        """Add every relation of `kb2` (with its source data) to this instance.

        Each relation goes through `add_relations`, so its entities are
        looked up again and duplicates are merged rather than appended.
        """
        for r in kb2.relations:
            article_url = list(r['meta'].keys())[0]
            # `sources` is a dict: it must be indexed, not called.
            source_data = kb2.sources[article_url]

            self.add_relations(r, source_data['article_title'], source_data['article_publish_date'])

    def merge_relations(self, r1):
        """Merge the metadata of `r1` into the stored relation equal to it.

        Only the first URL of `r1['meta']` is considered. If the stored
        relation has no entry for that URL the entry is adopted as-is;
        otherwise only spans not yet recorded are appended.

        Raises:
            IndexError: if no stored relation equals `r1`.
        """
        r2 = [r for r in self.relations if self.are_relations_equal(r1, r)][0]

        article_url = list(r1['meta'].keys())[0]

        if article_url not in r2['meta']:
            r2['meta'][article_url] = r1['meta'][article_url]
        else:
            spans_to_add = [span for span in r1['meta'][article_url]['span']
                            if span not in r2['meta'][article_url]['span']]

            r2['meta'][article_url]['span'] += spans_to_add

    def get_wikipedia_data(self, candidate_entity):
        """Look up `candidate_entity` on Wikipedia.

        Returns:
            A dict with 'title', 'url' and 'summary' of the page, or None
            when the lookup fails for any reason (best effort).
        """
        try:
            page = wikipedia.page(candidate_entity)

            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary
            }

            return entity_data
        except Exception:
            # Still best effort, but no longer swallows KeyboardInterrupt /
            # SystemExit the way a bare `except:` does.
            return None

    def add_entity(self, e):
        """Store entity `e` under its title, keeping every other field."""
        self.entities[e['title']] = {k: v for k, v in e.items() if k != 'title'}
    
            
        

In [24]:
def from_small_text_to_kb(text, verbose=False, article_url="", article_title=None,
                          article_publish_date=None):
    """Build a KnowledgeBase from a text that fits in a single model input.

    The text is tokenized with truncation at 512 tokens, three candidate
    sequences are generated with beam search, and every relation parsed from
    them is added to a fresh KnowledgeBase.

    `KnowledgeBase.add_relations` needs the article title, the publish date
    and a `meta` entry on every relation; the previous call
    `kb.add_relations(r)` supplied none of them and raised a TypeError.

    Args:
        text: the text to extract relations from.
        verbose: when True, print the number of input tokens.
        article_url: key under which the relation metadata and the source
            are recorded; defaults to "" for text without a known source.
        article_title: title recorded for the source (may be None).
        article_publish_date: publish date recorded for the source (may be None).

    Returns:
        The populated KnowledgeBase.
    """
    kb = KnowledgeBase()

    model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    num_tokens = len(model_inputs['input_ids'][0])

    if verbose:
        print(f"Num tokens: {num_tokens}")

    kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }

    generated_tokens = model.generate(**model_inputs, **kwargs)

    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        for r in relations:
            # The whole (possibly truncated) input is one span. A fresh list
            # is built per relation because KnowledgeBase extends it in place
            # when merging duplicates.
            r['meta'] = {
                article_url: {
                    "span": [[0, num_tokens]]
                }
            }
            kb.add_relations(r, article_title, article_publish_date)

    return kb

In [25]:
def from_text_to_kb(text, article_url, span_length=128, article_title=None,
                    article_publish_date=None, verbose=False):
    """Build a KnowledgeBase from a text of arbitrary length.

    The tokenized text is cut into equally sized, evenly overlapping windows
    of `span_length` tokens. All windows are sent to the model as one batch,
    three sequences are generated per window, and each parsed relation is
    tagged with the boundaries of the window it came from.

    Args:
        text: the text to extract relations from.
        article_url: URL used as the key of each relation's `meta` entry.
        span_length: number of tokens per window.
        article_title: title passed on to `KnowledgeBase.add_relations`.
        article_publish_date: publish date passed on to
            `KnowledgeBase.add_relations`.
        verbose: when True, print token count, span count and boundaries.

    Returns:
        The populated KnowledgeBase.
    """
    encoded = tokenizer([text], return_tensors='pt')
    token_ids = encoded['input_ids'][0]

    num_tokens = len(token_ids)
    if verbose:
        print(f"Input has {num_tokens} tokens")

    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")

    # Spread the surplus of num_spans * span_length over the gaps between
    # consecutive windows so the last window ends inside the text.
    overlap = math.ceil((num_spans * span_length - num_tokens) / max(num_spans - 1, 1))
    stride = span_length - overlap
    spans_boundaries = [
        [stride * span_index, stride * span_index + span_length]
        for span_index in range(num_spans)
    ]
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    tensor_ids = [token_ids[lo:hi] for lo, hi in spans_boundaries]
    tensor_masks = [encoded['attention_mask'][0][lo:hi] for lo, hi in spans_boundaries]
    batch = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    num_return_sequences = 3
    generation_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(**batch, **generation_kwargs)
    decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

    kb = KnowledgeBase()
    for pred_index, sentence_pred in enumerate(decoded_preds):
        # Predictions are grouped per window: num_return_sequences in a row.
        current_span_index = pred_index // num_return_sequences
        for relation in extract_relations_from_model_output(sentence_pred):
            relation['meta'] = {
                article_url: {
                    "span": [spans_boundaries[current_span_index]]
                }
            }
            kb.add_relations(relation, article_title, article_publish_date)

    return kb

In [None]:
def get_article(url):
    """Download and parse the article at `url`.

    Returns:
        The `Article` object after `download()` and `parse()` were called.
    """
    fetched = Article(url)
    fetched.download()
    fetched.parse()
    return fetched

def from_url_to_kb(url):
    """Build a KnowledgeBase from the article found at `url`.

    The article is fetched with `get_article`; its body text is analysed and
    its title and publish date are recorded as source data.

    Returns:
        The KnowledgeBase produced by `from_text_to_kb`.
    """
    article = get_article(url)
    config = {
        "article_title": article.title,
        "article_publish_date": article.publish_date
    }

    # Analyse the article body. Passing `article.title` here meant only the
    # headline was ever processed.
    kb = from_text_to_kb(article.text, article.url, **config)

    return kb

In [None]:
def get_news_links(query, lang='en', region='US', pages=1, max_links=100000):
    """Collect article links from Google News for `query`.

    Args:
        query: search string.
        lang: language passed to `GoogleNews`.
        region: region passed to `GoogleNews`.
        pages: number of result pages to request.
        max_links: maximum number of links returned.

    Returns:
        A list of unique links, in the order they were first seen, truncated
        to `max_links` entries.
    """
    googlenews = GoogleNews(lang=lang, region=region)
    googlenews.search(query)

    all_urls = []

    # NOTE(review): `get_page` is the GoogleNews pager and is believed to be
    # 1-based; the previous `googlenews.get(page)` is not a known method of
    # the class - confirm against the installed GoogleNews version.
    for page in range(1, pages + 1):
        googlenews.get_page(page)
        all_urls += googlenews.get_links()

    # De-duplicate while keeping first-seen order, so that truncating to
    # `max_links` is deterministic (a set has arbitrary order).
    return list(dict.fromkeys(all_urls))[:max_links]

def from_urls_to_kb(urls, verbose=False):
    """Build one KnowledgeBase out of the articles behind `urls`.

    Each URL is turned into its own KnowledgeBase with `from_url_to_kb` and
    merged into the result. URLs that raise `ArticleException` are skipped;
    any other exception propagates.

    Args:
        urls: sized collection of article URLs.
        verbose: when True, print progress and skipped URLs.

    Returns:
        The merged KnowledgeBase.
    """
    merged_kb = KnowledgeBase()
    if verbose:
        print(f"{len(urls)} links to visit")

    for url in urls:
        if verbose:
            print(f"Visiting {url}...")

        try:
            merged_kb.merge_with_kb(from_url_to_kb(url))
        except ArticleException:
            if verbose:
                print(f"  Couldn't download article at url {url}")

    return merged_kb

In [26]:
text = (
    "Napoleon Bonaparte (born Napoleone di Buonaparte; 15 August 1769 – 5 "
    "May 1821), and later known by his regnal name Napoleon I, was a French military "
    "and political leader who rose to prominence during the French Revolution and led "
    "several successful campaigns during the Revolutionary Wars. He was the de facto "
    "leader of the French Republic as First Consul from 1799 to 1804. As Napoleon I, "
    "he was Emperor of the French from 1804 until 1814 and again in 1815. Napoleon's "
    "political and cultural legacy has endured, and he has been one of the most "
    "celebrated and controversial leaders in world history."
)

# `from_text_to_kb` requires the source URL as its second argument; calling it
# without one raises a TypeError.
# NOTE(review): the passage looks like the opening of the English Wikipedia
# article on Napoleon, so that page is used as the source - adjust if the
# text came from elsewhere.
kb = from_text_to_kb(text, article_url="https://en.wikipedia.org/wiki/Napoleon", verbose=True)
kb.print()

Input has 133 tokens
Input has 2 spans
Span boundaries are [[0, 128], [5, 133]]


In [1]:
def save_network_html(kb, filename='network.html'):
    """Render the knowledge base as a directed pyvis network.

    Every entity title in `kb.entities` becomes a green circular node and
    every relation in `kb.relations` becomes an edge from its head to its
    tail, titled and labelled with the relation type. The finished graph is
    handed to `Network.show(filename)`.

    Args:
        kb: object exposing `entities` (iterable of titles) and `relations`
            (iterable of dicts with 'head', 'tail' and 'type').
        filename: name passed to `Network.show`.
    """
    graph = Network(directed=True, width='auto', height="700px", bgcolor="#eeeeee")

    node_color = '#00FF00'
    for entity_title in kb.entities:
        graph.add_node(entity_title, shape='circle', color=node_color)

    for relation in kb.relations:
        relation_type = relation['type']
        graph.add_edge(relation['head'], relation['tail'],
                       title=relation_type, label=relation_type)

    # Layout settings forwarded unchanged to the repulsion solver.
    repulsion_settings = dict(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    graph.repulsion(**repulsion_settings)
    graph.set_edge_smooth('dynamic')
    graph.show(filename)

In [None]:
# End-to-end run: gather at most 20 Google News links for the query "Google"
# from 5 result pages, build one knowledge base from them and render it.
filename = "network_3_google.html"
news_links = get_news_links("Google", pages=5, max_links=20)
kb = from_urls_to_kb(news_links, verbose=True)
save_network_html(kb, filename=filename)