In [1]:
import transformers
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import math
import torch
import wikipedia
from newspaper import Article, ArticleException
from GoogleNews import GoogleNews
import IPython
from pyvis.network import Network

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import requests
import json
def generate_context(query:str, num_urls:int):
    
    response= requests.post("https://qagen.paperbot.ai/extract_all_passages/", json={
                                                                                    "query": query,
                                                                                    "num_urls": int(num_urls),
                                                                                    } )
    
    if response.ok:
        # d= eval(response.content)
        paragrahs= json.loads(response.content.decode(
                                                        'utf-8'
                                                    ))['paragraphs']
        return paragrahs
        
    else:
        print("Couldn't get the response from the 'extract-all-passages'   🥲")
        return 

In [3]:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large").to("cuda")

In [4]:
def extract_relations_from_model_output(text):
    relations = []
    relation, subject, relation, object_ = '', '', '', ''
    text = text.strip()
    current = 'x'
    text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
    for token in text_replaced.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                relations.append({
                    'subject': subject.strip(),
                    'verb': relation.strip(),
                    'object': object_.strip()
                })
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                relations.append({
                    'subject': subject.strip(),
                    'verb': relation.strip(),
                    'object': object_.strip()
                })
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        else:
            if current == 't':
                subject += ' ' + token
            elif current == 's':
                object_ += ' ' + token
            elif current == 'o':
                relation += ' ' + token
    if subject != '' and relation != '' and object_ != '':
        relations.append({
            'subject': subject.strip(),
            'verb': relation.strip(),
            'object': object_.strip()
        })
    return relations

In [5]:
class KB():
    def __init__(self):
        self.relations = []

    def are_relations_equal(self, r1, r2):
        return all(r1[attr] == r2[attr] for attr in ["subject", "verb", "object"])

    def exists_relation(self, r1):
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    # def add_relation(self, r):
    #     if not self.exists_relation(r):
    #         self.relations.append(r)

    def print(self):
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")

    def merge_relations(self, r1):
        r2 = [r for r in self.relations
              if self.are_relations_equal(r1, r)][0]
        spans_to_add = [span for span in r1["meta"]["spans"]
                        if span not in r2["meta"]["spans"]]
        r2["meta"]["spans"] += spans_to_add

    def add_relation(self, r):
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

In [6]:

# def from_text_to_kb(text, verbose=False):
#     kb = KB()

#     # Tokenizer text
#     model_inputs = tokenizer(text, max_length=512, padding=True, truncation=True,
#                             return_tensors='pt')
#     if verbose:
#         print(f"Num tokens: {len(model_inputs['input_ids'][0])}")

#     # Generate
#     gen_kwargs = {
#         "max_length": 216,
#         "length_penalty": 0,
#         "num_beams": 3,
#         "num_return_sequences": 3
#     }
#     generated_tokens = model.generate(
#         **model_inputs.to(model.device),
#         **gen_kwargs,
#     )
#     decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

#     # create kb
#     for sentence_pred in decoded_preds:
#         relations = extract_relations_from_model_output(sentence_pred)
#         for r in relations:
#             kb.add_relation(r)
#     print(f"relations generated: {len(kb.relations)}")
    

#     return kb

In [7]:
def from_text_to_kb(text, span_length=128, verbose=False):
    # tokenize whole text
    inputs = tokenizer([text], return_tensors="pt")
    text= text.lower()

    # compute span boundaries
    num_tokens = len(inputs["input_ids"][0])
    if verbose:
        print(f"Input has {num_tokens} tokens")
    num_spans = math.ceil(num_tokens / span_length)
    if verbose:
        print(f"Input has {num_spans} spans")
    overlap = math.ceil((num_spans * span_length - num_tokens) / 
                        max(num_spans - 1, 1))
    spans_boundaries = []
    start = 0
    for i in range(num_spans):
        spans_boundaries.append([start + span_length * i,
                                 start + span_length * (i + 1)])
        start -= overlap
    if verbose:
        print(f"Span boundaries are {spans_boundaries}")

    # transform input with spans
    tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]]
                  for boundary in spans_boundaries]
    tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]]
                    for boundary in spans_boundaries]
    inputs = {
        "input_ids": torch.stack(tensor_ids),
        "attention_mask": torch.stack(tensor_masks)
    }
    inputs= transformers.tokenization_utils_base.BatchEncoding(inputs)

    # generate relations
    num_return_sequences = 3
    gen_kwargs = {
        "max_length": 256,
        "length_penalty": 0,
        "num_beams": 3,
        "num_return_sequences": num_return_sequences
    }
    generated_tokens = model.generate(
        **inputs.to(model.device),
        **gen_kwargs,
    )

    # decode relations
    decoded_preds = tokenizer.batch_decode(generated_tokens,
                                           skip_special_tokens=False)

    # create kb
    kb = KB()
    i = 0
    for sentence_pred in decoded_preds:
        current_span_index = i // num_return_sequences
        relations = extract_relations_from_model_output(sentence_pred)
        for relation in relations:
            relation["meta"] = {
                "spans": [spans_boundaries[current_span_index]]
            }
            kb.add_relation(relation)
        i += 1
    print(f"relations generated: {len(kb.relations)}")
    # print(kb.relations)
    return kb

In [8]:
para= generate_context("best cat ear headphones", 10)


In [9]:
para[:2]
len(para)


930

In [11]:
kb= KB()

print(f"***************  Total Paragraphs: {len(para)}  ***************")

i= 1
for p in para:
    print(f"paragraph {i}:")
    i+=1
    # print()
    temp_kb= from_text_to_kb(p.lower(), span_length=10,verbose=True)
    kb.relations+=(temp_kb.relations)
    print()
    
    

***************  Total Paragraphs: 930  ***************
paragraph 1:
Input has 24 tokens
Input has 3 spans
Span boundaries are [[0, 10], [7, 17], [14, 24]]


relations generated: 6

paragraph 2:
Input has 36 tokens
Input has 4 spans
Span boundaries are [[0, 10], [8, 18], [16, 26], [24, 34]]
relations generated: 10

paragraph 3:
Input has 80 tokens
Input has 8 spans
Span boundaries are [[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80]]
relations generated: 23

paragraph 4:
Input has 62 tokens
Input has 7 spans
Span boundaries are [[0, 10], [8, 18], [16, 26], [24, 34], [32, 42], [40, 50], [48, 58]]
relations generated: 21

paragraph 5:
Input has 87 tokens
Input has 9 spans
Span boundaries are [[0, 10], [9, 19], [18, 28], [27, 37], [36, 46], [45, 55], [54, 64], [63, 73], [72, 82]]
relations generated: 23

paragraph 6:
Input has 50 tokens
Input has 5 spans
Span boundaries are [[0, 10], [10, 20], [20, 30], [30, 40], [40, 50]]
relations generated: 15

paragraph 7:
Input has 72 tokens
Input has 8 spans
Span boundaries are [[0, 10], [8, 18], [16, 26], [24, 34], [32, 42], [40, 50], [48, 58], [56, 66]]
relations generated

In [12]:
len(kb.relations)


12465

In [13]:
import os; os.chdir('..')

In [14]:
json.dump(kb.relations, open("dummy_outputs/relation.json", 'w'))
