# [End-to-End Application] Building Knowledge Graphs with Rebel: 
## Extract, Visualize and Ingest the triplets into Neo4j Database

![image.png](attachment:486ae067-65d2-41cb-98a6-72c3f382dd4f.png)


![image.png](attachment:15a65bb3-206f-4ade-b8b4-fb701042a4c1.png)

## Import packages

In [1]:
!pip install transformers neo4j langchain > /dev/null

In [19]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import IPython
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


## Load the REBEL model

In [20]:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [21]:
test_text = "This notebook shows how to implement a pipeline for extracting a Knowledge Base from texts or documents."

In [22]:
# Tokenizer text
model_inputs = tokenizer(test_text,
                          max_length=512,
                          padding=True,
                          truncation=True,
                        return_tensors='pt')

print(f"Num tokens: {len(model_inputs['input_ids'][0])}")

# Generate
gen_kwargs = {
    "max_length": 216,
    "length_penalty": 0,
    "num_beams": 5,
    "num_return_sequences": 4
}
generated_tokens = model.generate(
    **model_inputs,
    **gen_kwargs,
)
decoded_preds = tokenizer.batch_decode(generated_tokens,
                                        skip_special_tokens=False)

decoded_preds

Num tokens: 20


['<s><triplet> documents <subj> texts <obj> has part</s><pad>',
 '<s><triplet> Knowledge Base <subj> pipeline <obj> instance of</s>',
 '<s><triplet> Knowledge Base <subj> pipeline <obj> use</s><pad>',
 '<s><triplet> document <subj> texts <obj> has part</s><pad>']

In [23]:
def extract_relations_from_model_output(text):
    relations = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'head': subject.strip(),
                    'type': relation.strip(),
                    'tail': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })
    return relations

In [24]:
class KB():
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

In [25]:
def from_small_text_to_kb(text, verbose=False):
    kb = KB()

    # Tokenizer text
    model_inputs = tokenizer(text,
                             max_length=512,
                             padding=True,
                             truncation=True,
                            return_tensors='pt')
    if verbose:
        print(f"Num tokens: {len(model_inputs['input_ids'][0])}")

    # Generate
    gen_kwargs = {
        "max_length": 216,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": 3
    }
    generated_tokens = model.generate(
        **model_inputs,
        **gen_kwargs,
    )
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    for sentence_pred in decoded_preds:
        relations = extract_relations_from_model_output(sentence_pred)
        for r in relations:
            kb.add_relation(r)

    return kb

In [26]:
small_kb = from_small_text_to_kb(test_text)

In [27]:
small_kb.print()

Relations:
  {'head': 'documents', 'type': 'has part', 'tail': 'texts'}
  {'head': 'Knowledge Base', 'type': 'instance of', 'tail': 'pipeline'}
  {'head': 'texts', 'type': 'part of', 'tail': 'documents'}


## Split spans: from long text to KB

In [28]:
class SpankB():
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["head",
                                                     "type",
                                                     "tail"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r1):
        r2 = [r for r in self.relations
              if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

    def save_csv(self,file_name):
        print(f"Saving to file {file_name}")
        reln_df = pd.DataFrame(self.relations)
        reln_df.to_csv(file_name,index=False)

In [29]:
def from_text_to_kb(text,
                    span_length=50,
                    verbose=False):
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) /
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in range(num_spans):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs,
        **gen_kwargs,
    )

    # decode relations
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    kb = SpankB()
    i = 0
    for sentence_pred in decoded_preds:
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                "spans": [spans_boundaries[current_span_index]]
            }
            kb.add_relation(relation)
        i += 1

    return kb

In [30]:
long_text = """Knowledge Bases and Knowledge Graphs
A Knowledge Base (KB) is information stored as structured data, ready to be used for analysis or inference. Usually, a KB is stored as a graph (i.e. a Knowledge Graph), where nodes are entities and edges are relations between entities.

For example, from the text “Fabio lives in Italy” we can extract the relation triplet <Fabio, lives in, Italy>, where “Fabio” and “Italy” are entities.

Extracting relation triplets from raw text is a crucial task in Information Extraction, enabling multiple applications such as populating or validating knowledge bases, fact-checking, and other downstream tasks.

How to build a Knowledge Graph
To build a knowledge graph from text, we typically need to perform two steps:

Extract entities, a.k.a. Named Entity Recognition (NER), which are going to be the nodes of the knowledge graph.
Extract relations between the entities, a.k.a. Relation Classification (RC), which are going to be the edges of the knowledge graph.
These multiple-step pipelines often propagate errors or are limited to a small number of relation types. Recently, end-to-end approaches have been proposed to tackle both tasks simultaneously. This task is usually referred to as Relation Extraction (RE). In this article, we’ll use an end-to-end model called REBEL, from the paper Relation Extraction By End-to-end Language generation.

How REBEL works
REBEL is a text2text model trained by BabelScape by fine-tuning BART for translating a raw input sentence containing entities and implicit relations into a set of triplets that explicitly refer to those relations. It has been trained on more than 200 different relation types.

The authors created a custom dataset for REBEL pre-training, using entities and relations found in Wikipedia abstracts and Wikidata, and filtering them using a RoBERTa Natural Language Inference model (similar to this model). Have a look at the paper to know more about the creation process of the dataset. The authors also published their dataset on the Hugging Face Hub.

The model performs quite well on an array of Relation Extraction and Relation Classification benchmarks.

You can find REBEL in the Hugging Face Hub."""

In [31]:
kb = from_text_to_kb(long_text,
                     verbose=True)

Input has 508 tokens
Input has 11 spans
Span boundaries are [[0, 50], [45, 95], [90, 140], [135, 185], [180, 230], [225, 275], [270, 320], [315, 365], [360, 410], [405, 455], [450, 500]]


In [32]:
kb.print()

Relations:
  {'head': 'Knowledge Base', 'type': 'subclass of', 'tail': 'graph', 'meta': {'spans': [[0, 50]]}}
  {'head': 'Knowledge Base', 'type': 'subclass of', 'tail': 'structured data', 'meta': {'spans': [[0, 50]]}}
  {'head': 'nodes', 'type': 'part of', 'tail': 'Knowledge Graph', 'meta': {'spans': [[45, 95]]}}
  {'head': 'edges', 'type': 'part of', 'tail': 'Knowledge Graph', 'meta': {'spans': [[45, 95]]}}
  {'head': 'Knowledge Graph', 'type': 'has parts of the class', 'tail': 'nodes', 'meta': {'spans': [[45, 95]]}}
  {'head': 'Knowledge Graph', 'type': 'has parts of the class', 'tail': 'edges', 'meta': {'spans': [[45, 95]]}}
  {'head': 'relation triplets', 'type': 'subclass of', 'tail': 'entities', 'meta': {'spans': [[90, 140]]}}
  {'head': 'relation triplet', 'type': 'subclass of', 'tail': 'entities', 'meta': {'spans': [[90, 140]]}}
  {'head': 'Extracting relation triplets', 'type': 'use', 'tail': 'Information Extraction', 'meta': {'spans': [[90, 140]]}}
  {'head': 'fact-checking'

## Visualize the KG


In [33]:
kb.save_csv("../data/relations.csv")

Saving to file ../data/relations.csv


In [38]:
from pyvis.network import Network

def save_network_html(kb, filename="network.html"):
    # create network
    net = Network(directed=True, width="700px", height="700px", cdn_resources = "in_line")

    # nodes
    color_entity = "#00FF00"
    
    for e in kb.relations:
        net.add_node(e['head'], shape="circle", color=color_entity)
        net.add_node(e['tail'], shape="circle", color=color_entity)
    # edges
    for r in kb.relations:
        net.add_edge(r["head"], r["tail"],
                    title=r["type"], label=r["type"])

    # save network
    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show('../networks/' + filename, notebook=False)


In [None]:
save_network_html(kb, filename="network.html")

In [1]:

from IPython.display import HTML

html_file = "../networks/network.html"
display(HTML(html_file))


## Ingest the triplets into Neo4j Database

In [2]:
def sanitize(text):
    text = str(text).replace("'","").replace('"','').replace('{','').replace('}', '').replace(" ", "_").replace('-', '_')
    return text


In [4]:
import pandas as pd
df = pd.read_csv('../data/relations.csv')
df = df.drop(columns={'meta'})
for col in df.columns:
    df[col] = df[col].apply(sanitize).str.lower()
df

Unnamed: 0,head,type,tail
0,nodes,part_of,graph
1,entities,part_of,graph
2,node,part_of,graph
3,nodes,opposite_of,edges
4,edges,opposite_of,nodes
5,nodes,opposite_of,edge
6,named_entity_recogn,subclass_of,entities
7,named_entity_recogn,instance_of,entities
8,named_entity_recogn,use,knowledge_graph
9,named_entity_recognition,subclass_of,relation_classification


In [7]:

from neo4j import GraphDatabase


class Neo4jConnection:
    
    def __init__(self, uri, user, pwd):
        self.__uri = uri
        self.__user = user
        self.__pwd = pwd
        self.__driver = None
        try:
            self.__driver = GraphDatabase.driver(self.__uri, auth=(self.__user, self.__pwd))
        except Exception as e:
            print("Failed to create the driver:", e)
        
    def close(self):
        if self.__driver is not None:
            self.__driver.close()
        
    def query(self, query, db=None):
        assert self.__driver is not None, "Driver not initialized!"
        session = None
        response = None
        try: 
            session = self.__driver.session(database=db) if db is not None else self.__driver.session() 
            response = list(session.run(query))
        except Exception as e:
            print("Query failed:", e)
        finally: 
            if session is not None:
                session.close()
        return response

In [8]:
uri = "bolt://localhost:7687"
username = "neo4j"
password = "neo4j080503"
database = 'demo'

conn = Neo4jConnection(uri=uri, user=username, pwd=password)

# Loop through data and create Cypher query
for i in range(len(df['head'])):

    query = f'''
        MERGE (head:Head {{name: "{df['head'][i]}"}})

        MERGE (tail:tail {{value: "{df['tail'][i]}"}})

        MERGE (head)-[:{df['type'][i]}]->(tail)
        '''
    result = conn.query(query, db=database)
print('Data Ingested to Neo4j.')

In [10]:
# ## you can also use Neo4jGraph with Langchain to ingest the data to Neo4j DB.
# from langchain.graphs import Neo4jGraph

# graph = Neo4jGraph(
#     url="bolt://localhost:7687", 
#     username="neo4j", 
#     password="neo4j080503",
#     database="demo"
# )


# # Loop through data and create Cypher query
# for i in range(len(df['head'])):

#     query = f'''
#         MERGE (head:Head {{name: "{df['head'][i]}"}})

#         MERGE (tail:tail {{value: "{df['tail'][i]}"}})

#         MERGE (head)-[:{sanitize(df['type'][i])}]->(tail)
#         '''
#     # print(query)
#     graph.query(query)