In [76]:
import PyPDF2
import os
import itertools
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import math

In [4]:
# Load the REBEL tokenizer and seq2seq model once; the extraction functions
# below read these module-level objects.
tokenizer, model = (
    AutoTokenizer.from_pretrained("Babelscape/rebel-large"),
    AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large"),
)

Downloading model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

# Preprocessing

In [5]:
# os.listdir returns entries in arbitrary order and the next cells pick pdfs[0],
# so sort for a deterministic choice; keep only PDFs so stray files such as
# .DS_Store are never handed to PyPDF2.
file_names = sorted(name for name in os.listdir('./pdfs/') if name.lower().endswith('.pdf'))

In [8]:
# Full relative paths of the PDFs found above.
pdfs = list(map(lambda name: os.path.join("./pdfs/", name), file_names))

In [26]:
# Only the first PDF is read: the rest of the notebook treats `text` as a
# single textbook.
with open(pdfs[0], 'rb') as pdf:
    reader = PyPDF2.PdfReader(pdf)
    # Separate pages with a newline so the last word of one page is not glued
    # to the first word of the next. `or ''` guards a page that yields no text
    # (NOTE(review): whether extract_text can return None depends on the
    # PyPDF2 version).
    text = '\n'.join(page.extract_text() or '' for page in reader.pages)
            

In [67]:
num_pages = 3525  # presumably the textbook's page count — verify against the PDF
tbook_len = len(text)
# split_text takes the NUMBER of chunks; we want one chunk per page.
# (tbook_len // num_pages is the average characters per page, and passing it
# here would instead produce chunks of roughly num_pages characters each.)
n_segments = num_pages

# It seems that each disease is more or less on a single page, so rather than trying to parse the text using
# the table of contents, I split the text into one segment per page, with the idea of running NER per page.

In [32]:
def split_text(text, num_segments):
    """Split text into exactly num_segments contiguous chunks of near-equal length.

    The first ``len(text) % num_segments`` chunks are one character longer,
    so the chunks concatenate back to exactly ``text`` and no separate
    leftover (often empty) chunk is produced. If num_segments exceeds
    len(text), the trailing chunks are empty strings.

    Args:
        text: The string to split.
        num_segments: Number of chunks to return; must be >= 1.

    Returns:
        A list of num_segments strings whose concatenation equals text.

    Raises:
        ValueError: if num_segments < 1.
    """
    if num_segments < 1:
        raise ValueError(f"num_segments must be >= 1, got {num_segments}")
    base, extra = divmod(len(text), num_segments)
    chunks = []
    start = 0
    for i in range(num_segments):
        # Spread the remainder over the first `extra` chunks.
        end = start + base + (1 if i < extra else 0)
        chunks.append(text[start:end])
        start = end
    return chunks

In [33]:
# Cut the full textbook text into n_segments pieces for per-segment extraction.
segments = split_text(text=text, num_segments=n_segments)

## NER Triplet Extraction and building the graph
Given the large size of the textbook, I tried to extract triplets one page (roughly one disease) at a time and then build a graph from that page. The process iterates over the pages, building a graph per page (without re-adding existing entities) and iteratively combining the graphs until every page has been processed.

In [35]:
def extract_relations_from_model_output(text):
    """Parse REBEL's linearized generation into relation triplets.

    The expected format is ``<triplet> head <subj> tail <obj> relation``.
    Further ``<subj> tail <obj> relation`` pairs may follow and reuse the
    same head.

    Args:
        text: One decoded generation. It may still contain ``<s>``, ``</s>``
            and ``<pad>`` tokens, which are removed before parsing.

    Returns:
        A list of dicts with keys ``'head'``, ``'type'`` and ``'tail'``.
        A trailing triplet is kept only if its head, tail and relation are
        all non-empty. A tail cut off before its ``<obj>`` (for example by
        the generation length limit) is dropped.
    """
    relations = []
    subject, relation, object_ = '', '', ''
    current = 'x'  # field the next plain token belongs to ('x' = none yet)
    cleaned = text.strip().replace("<s>", "").replace("<pad>", "").replace("</s>", "")

    def emit():
        # Record the triplet accumulated so far.
        relations.append({
            'head': subject.strip(),
            'type': relation.strip(),
            'tail': object_.strip()
        })

    for token in cleaned.split():
        if token == "<triplet>":
            # A new head begins: flush the previous triplet if it has a relation.
            current = 't'
            if relation != '':
                emit()
                relation = ''
            subject = ''
        elif token == "<subj>":
            # Another (tail, relation) pair for the same head: flush the previous one.
            current = 's'
            if relation != '':
                emit()
                # Reset so a tail cut off before its <obj> is not paired at
                # the end of the sequence with this already-emitted relation.
                relation = ''
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    if subject != '' and relation != '' and object_ != '':
        emit()
    return relations

In [73]:
class KB():
    """Deduplicated store of relation triplets.

    Each relation is a dict with ``'head'``, ``'type'`` and ``'tail'`` keys.
    When an already-stored triplet is added again, ``relation['meta']['spans']``
    of the new copy is merged into the stored copy. Relations passed to
    ``add_relation`` must therefore carry that key.
    """

    def __init__(self):
        # Stored relation dicts, in insertion order, with no duplicate triplets.
        self.relations = []

    def are_relations_equal(self, r1, r2):
        """Return True if r1 and r2 have the same head, type and tail."""
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """Return True if a relation equal to r1 is already stored."""
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def get_entities(self, e=None):
        """Return the set of every head and tail among the stored relations.

        ``e`` is ignored. It is optional so that both ``get_entities()`` and
        older ``get_entities(x)`` calls work.
        """
        heads = {r['head'] for r in self.relations}
        tails = {r['tail'] for r in self.relations}
        return heads | tails

    def merge_relations(self, r1):
        """Append r1's spans that are not yet recorded on the stored equal relation."""
        r2 = [r for r in self.relations
              if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def add_relation(self, r):
        """Store r if its triplet is new; otherwise merge its spans into the stored copy."""
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)


In [77]:
def _span_boundaries(num_tokens, span_length):
    """Return [start, end] token windows of width span_length that cover num_tokens tokens.

    With a single window the result is ``[[0, span_length]]``; slicing clamps
    it to the available tokens. With several windows the starts are spread
    evenly between 0 and ``num_tokens - span_length``. Consecutive windows
    may therefore overlap, and the last one ends exactly at num_tokens, so no
    trailing tokens are dropped.
    """
    num_spans = max(math.ceil(num_tokens / span_length), 1)
    if num_spans == 1:
        return [[0, span_length]]
    travel = num_tokens - span_length  # start offset of the last window
    return [[start, start + span_length]
            for start in (i * travel // (num_spans - 1) for i in range(num_spans))]


def from_text_to_kb(text, span_length=128):
    """Extract relation triplets from text with the REBEL model.

    The text is tokenized once and cut into windows of span_length tokens
    (see _span_boundaries). All windows are decoded in one batched beam
    search, and every returned sequence is parsed into triplets.

    Args:
        text: Input text (one textbook segment).
        span_length: Window width in tokens.

    Returns:
        A tuple ``(relations, entities)``:
        - relations: deduplicated relation dicts. Each has
          ``relation['meta']['spans']``, the token windows it was found in.
        - entities: the set of strings appearing as a head or tail.
    """
    inputs = tokenizer([text], return_tensors="pt")
    input_ids = inputs["input_ids"][0]
    attention_mask = inputs["attention_mask"][0]

    spans_boundaries = _span_boundaries(len(input_ids), span_length)
    batch = {
        "input_ids": torch.stack([input_ids[s:e] for s, e in spans_boundaries]),
        "attention_mask": torch.stack([attention_mask[s:e] for s, e in spans_boundaries]),
    }

    num_return_sequences = 3
    gen_kwargs = {"max_length": 256,
                  "length_penalty": 0,
                  "num_beams": 3,
                  "num_return_sequences": num_return_sequences}
    generated_tokens = model.generate(**batch, **gen_kwargs)

    # Keep special tokens: the parser relies on the <triplet>/<subj>/<obj> markers.
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    kb = KB()
    for i, sentence_pred in enumerate(decoded_preds):
        # generate() yields num_return_sequences outputs per input window,
        # grouped by window.
        span = spans_boundaries[i // num_return_sequences]
        for relation in extract_relations_from_model_output(sentence_pred):
            relation["meta"] = {"spans": [span]}
            kb.add_relation(relation)

    # Every head and tail of a stored relation is an entity.
    entities = {r["head"] for r in kb.relations} | {r["tail"] for r in kb.relations}

    return kb.relations, entities

In [None]:
# WARNING this will take a long time to run

all_relations = []
all_entities = []
for segment in segments:
    relations, entities = from_text_to_kb(segment)
    all_relations.append(relations)
    all_entities.append(entities)

# Relation dicts are unhashable, so set() cannot deduplicate them. Key each one
# on its (head, type, tail) triple instead and keep the first occurrence.
unique_relations = {}
for relation in itertools.chain.from_iterable(all_relations):
    unique_relations.setdefault((relation["head"], relation["type"], relation["tail"]), relation)
final_relations = list(unique_relations.values())
# Sorted so the graph's node order is stable across runs.
final_entities = sorted(set(itertools.chain.from_iterable(all_entities)))

## Visualization

In [42]:
from pyvis.network import Network

In [53]:
def save_network_html(relations, entities, filename="network.html"):
    """Render relations as a directed pyvis graph and write it to an HTML file.

    Args:
        relations: Iterable of dicts with 'head', 'type' and 'tail' keys.
            Each becomes an edge head -> tail labelled with its type.
        entities: Iterable of entity names. Each becomes a node. Relation
            endpoints missing from it are added as nodes too.
        filename: Path of the HTML file to write.
    """
    relations = list(relations)  # iterated twice below
    net = Network(directed=True, width="auto", height="700px", bgcolor="#eeeeee")
    color_entity = "#00FF00"

    # pyvis rejects edges whose endpoints are not existing nodes, so include
    # every relation endpoint. dict.fromkeys drops duplicates in first-seen order.
    endpoints = [name for r in relations for name in (r["head"], r["tail"])]
    for e in dict.fromkeys([*entities, *endpoints]):
        net.add_node(e, shape="circle", color=color_entity)

    for r in relations:
        net.add_edge(r["head"], r["tail"],
                     title=r["type"], label=r["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    # save_graph only writes the file. show() also tries to display or open it,
    # and in some pyvis versions fails outside notebook mode — NOTE(review).
    net.save_graph(filename)

In [62]:
# Write the merged knowledge graph to an interactive HTML page.
filename = "kg.html"
save_network_html(relations=final_relations, entities=final_entities, filename=filename)