In [1]:
import requests
import pickle
import json
import os
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed
from urllib.parse import unquote
from tqdm.auto import tqdm, trange
from kilt.knowledge_source import KnowledgeSource

In [None]:
ks = KnowledgeSource()

# `title2id` (page title -> Wikipedia id) is passed to get_id_title by the
# threaded anchor-mapping cell. It used to be loaded only in commented-out code,
# so a fresh kernel failed there with a NameError.
ID_TITLE_PATH = "/checkpoint/fabiopetroni/GENRE/checkpoint/GeNeRe/data/id_title.json"
with open(ID_TITLE_PATH) as f:
    id2title = json.load(f)
title2id = {v: k for k, v in id2title.items()}

In [None]:
def get_id_title(anchor, title2id, timeout=10):
    """Resolve a Wikipedia anchor href to a page title and Wikipedia id.

    Resolution order:
      1. Direct lookup of the normalized href in ``title2id``.
      2. Fallback: an HTTP HEAD request to en.wikipedia.org, following
         redirects, then lookup of the resulting title.

    Args:
        anchor: raw anchor ``href`` (percent-encoded, ``_`` for spaces,
            optionally with a ``#fragment``).
        title2id: mapping from page title to Wikipedia id.
        timeout: seconds to wait for the fallback HEAD request. ``None``
            waits indefinitely (the previous behaviour).

    Returns:
        ``{"wikipedia_title": title, "wikipedia_id": id}``. Both values are
        ``None`` when the href contains "http", is empty after dropping the
        fragment, cannot be matched, or the HTTP request fails.
    """
    unresolved = {"wikipedia_title": None, "wikipedia_id": None}

    # Treated as an external link rather than a Wikipedia page.
    if "http" in anchor:
        return unresolved

    # Decode percent-escapes, drop any "#section" fragment, and use spaces.
    unquoted = unquote(anchor).split("#")[0].replace("_", " ")
    if not unquoted:
        return unresolved

    # Upper-case the first character; presumably title2id keys are capitalized
    # the way Wikipedia titles are.
    unquoted = unquoted[0].upper() + unquoted[1:]
    if unquoted in title2id:
        return {"wikipedia_title": unquoted, "wikipedia_id": title2id[unquoted]}

    # Fall back to Wikipedia: following redirects yields the target page URL.
    try:
        response = requests.head(
            "https://en.wikipedia.org/wiki/{}".format(anchor),
            allow_redirects=True,
            timeout=timeout,
        )
    except requests.RequestException:
        # One failed request must not abort the whole threaded mapping run.
        return unresolved

    # Keep everything after "/wiki/" so titles containing "/" (e.g. "AC/DC")
    # are not truncated. Unquote it, because the final URL is percent-encoded
    # for non-ASCII titles.
    final_url = response.url.split("#")[0]
    wikipedia_title = unquote(final_url.split("/wiki/", 1)[-1]).replace("_", " ")
    wikipedia_id = title2id.get(wikipedia_title) if wikipedia_title else None
    if wikipedia_id is not None:
        return {"wikipedia_title": wikipedia_title, "wikipedia_id": wikipedia_id}

    return unresolved

In [None]:
# Collect the href of every anchor on every KILT page. Duplicates are kept
# here; they are collapsed into a set in the next cell.
anchors = []
page_bar = tqdm(ks.get_all_pages_cursor(), total=ks.get_num_pages())
for page in page_bar:
    anchors.extend(link["href"] for link in page["anchors"])
    page_bar.set_postfix(anchors=len(anchors), refresh=False)

In [None]:
# Deduplicate the collected hrefs; running this cell again is harmless.
anchors = set(anchors)

In [None]:
# Number of distinct anchor hrefs (anchors is a set after the previous cell).
len(anchors)

In [None]:
# Persist the deduplicated href set so the full page scan need not be repeated.
with open("all_kilt_anchors.pkl", mode="wb") as anchors_file:
    pickle.dump(anchors, anchors_file)

In [None]:
# Reload the href set saved by the previous cell.
with open("all_kilt_anchors.pkl", mode="rb") as anchors_file:
    anchors = pickle.load(anchors_file)

In [None]:
# Size of the 1% subset kept by the next cell (divisor 100).
len(anchors) // 100

In [None]:
# Smoke-test subset: keep 1/ANCHOR_SAMPLE_DIVISOR of the hrefs for a quick run.
# `sorted` makes the subset reproducible. Iteration order of a set of strings
# depends on the per-process hash seed, so `list(anchors)[:k]` picked a
# different subset after every kernel restart.
# NOTE: not idempotent. Re-running this cell shrinks `anchors` again, so reload
# all_kilt_anchors.pkl first.
ANCHOR_SAMPLE_DIVISOR = 100
anchors = sorted(anchors)[:len(anchors) // ANCHOR_SAMPLE_DIVISOR]

In [None]:
# Resolve every href to {"wikipedia_title", "wikipedia_id"} on a thread pool.
# get_id_title may issue an HTTP request per href, so the work is I/O bound.
num_threads = 64
with ThreadPoolExecutor(max_workers=num_threads) as pool:
    future_to_anchor = {}
    for href in tqdm(anchors):
        future_to_anchor[pool.submit(get_id_title, href, title2id)] = href

    completed = tqdm(as_completed(future_to_anchor), total=len(future_to_anchor), smoothing=0)
    results = {future_to_anchor[done]: done.result() for done in completed}

# NOTE(review): the shard-merging cell below writes this same path.
with open("all_kilt_anchors_map.pkl", "wb") as map_file:
    pickle.dump(results, map_file)

In [None]:
# Merge the per-shard anchor maps (all_kilt_anchors_map_<i>.pkl) into one dict.
# As before, a later shard wins when the same href appears in several shards.
NUM_SHARDS = 32
results = {}
missing_shards = []
for i in trange(NUM_SHARDS):
    shard_path = "all_kilt_anchors_map_{}.pkl".format(i)
    if not os.path.exists(shard_path):
        missing_shards.append(i)
        continue
    with open(shard_path, "rb") as f:
        results.update(pickle.load(f))

# Missing shards used to be skipped silently, leaving their hrefs with no entry
# in `results`. Make that visible here instead of surfacing much later.
if missing_shards:
    print("WARNING: missing anchor-map shards: {}".format(missing_shards))

with open("all_kilt_anchors_map.pkl", "wb") as f:
    pickle.dump(results, f)

In [None]:
# Number of hrefs in the merged anchor map.
len(results)

In [None]:
# Write the resolved title/id back onto every anchor in the KILT knowledge source.
# hrefs absent from `results` (e.g. a missing shard or a subset-only mapping run)
# are stored as unresolved (None/None) instead of raising KeyError mid-update.
# Every anchor therefore ends up with both keys, which the mention/entity table
# cell reads unconditionally.
UNRESOLVED_ANCHOR = {"wikipedia_title": None, "wikipedia_id": None}
for page in tqdm(ks.get_all_pages_cursor(), total=ks.get_num_pages()):
    anchors = page["anchors"]
    for anchor in anchors:
        resolved = results.get(anchor["href"], UNRESOLVED_ANCHOR)
        anchor["wikipedia_title"] = resolved["wikipedia_title"]
        anchor["wikipedia_id"] = resolved["wikipedia_id"]
    ks.db.find_one_and_update(
        {"_id": page["wikipedia_id"]}, {"$set": {"anchors": anchors}}, upsert=True,
    )

In [None]:
# Reload the merged href -> {"wikipedia_title", "wikipedia_id"} map and report its size.
with open("all_kilt_anchors_map.pkl", "rb") as map_file:
    results = pickle.load(map_file)
print(len(results))

In [None]:
from collections import defaultdict

In [None]:
# Count how often each anchor text (mention) links to each resolved Wikipedia
# title: {mention: {title: count}}. Anchors with a falsy title are skipped.
mention_entitiy_table = defaultdict(lambda: defaultdict(int))
for page in tqdm(ks.get_all_pages_cursor(), total=ks.get_num_pages()):
    for link in page["anchors"]:
        title = link["wikipedia_title"]
        if not title:
            continue
        mention_entitiy_table[link["text"]][title] += 1

In [None]:
# Convert the nested defaultdicts to plain dicts (a lambda default factory
# cannot be pickled), then save the table.
mention_entitiy_table = {
    mention: dict(title_counts)
    for mention, title_counts in mention_entitiy_table.items()
}
with open("mention_entitiy_table.pkl", "wb") as table_file:
    pickle.dump(mention_entitiy_table, table_file)

In [None]:
import pickle

# Reload the saved {mention: {title: count}} table.
with open("mention_entitiy_table.pkl", "rb") as table_file:
    mention_entitiy_table = pickle.load(table_file)

In [None]:
# Spot-check: link-target counts for the mention "Carlo" (KeyError if it never occurs).
mention_entitiy_table["Carlo"]

In [None]:
# Number of distinct mention strings in the table.
len(mention_entitiy_table)

In [None]:
# Total number of counted (mention, title) link occurrences.
sum(sum(title_counts.values()) for title_counts in mention_entitiy_table.values())

In [None]:
import json, jsonlines

In [None]:
# Page-view dump; entries are expected to carry an "article" title (read below).
with open("2018.json") as views_file:
    views = json.load(views_file)

In [None]:
# Presumably the Natural Questions dev split in KILT format (judging by the file name).
with jsonlines.open("/private/home/ndecao/KILT/data/nq-dev-kilt.jsonl") as reader:
    data = list(reader)

In [None]:
# Set of article titles from the page-view dump, for fast membership tests.
views500 = {entry["article"] for entry in views}

In [None]:
# Fraction of examples with at least one provenance page whose title is in views500.
sum(
    any(
        prov["title"] in views500
        for out in example["output"]
        if "provenance" in out
        for prov in out["provenance"]
    )
    for example in data
) / len(data)

In [None]:
import jsonlines

In [None]:
# Collect each page's leading passages: text[1:] up to (not including) the first
# passage containing "Section::::". text[0] is skipped (presumably the title).
abstracts = []
page_bar = tqdm(ks.get_all_pages_cursor(), total=ks.get_num_pages())
for page in page_bar:
    for passage in page["text"][1:]:
        if "Section::::" in passage:
            break
        abstracts.append(passage)
    page_bar.set_postfix(passages=len(abstracts), refresh=False)

In [None]:
def batch_iter(obj, batch_size=1):
    """Yield consecutive lists of up to ``batch_size`` items from ``obj``.

    Every batch has exactly ``batch_size`` items except possibly the last one.
    An empty iterable yields nothing.

    Raises:
        ValueError: (on first iteration) if ``batch_size`` < 1. Previously 0
            produced spurious empty batches and negatives one giant batch.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    batch = []
    for item in obj:
        batch.append(item)
        # Yield as soon as the batch is full instead of waiting for the next item.
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

In [8]:
from genre import GENRE
import jsonlines

def batch_iter(obj, batch_size=1):
    """Yield consecutive lists of up to ``batch_size`` items from ``obj``.

    NOTE(review): an earlier cell defines the same function. This copy lets
    the GENRE cells run in a session where that cell was not executed. Keep the
    two in sync, or move the helper into a module.

    Every batch has exactly ``batch_size`` items except possibly the last one.
    An empty iterable yields nothing.

    Raises:
        ValueError: (on first iteration) if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    batch = []
    for item in obj:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

# Load the two fine-tuned GENRE checkpoints (eval mode, same device):
#   context2answer        passage              -> sampled candidate answers
#   answer_context2query  "answer >> passage"  -> sampled questions
GENRE_DEVICE = "cuda:1"
GENRE_CHECKPOINT_FILE = "checkpoint43.pt"


def _load_genre_model(model_dir):
    """Load GENRE_CHECKPOINT_FILE from ``model_dir`` in eval mode on GENRE_DEVICE."""
    model = GENRE.from_pretrained(model_dir, GENRE_CHECKPOINT_FILE)
    return model.eval().to(GENRE_DEVICE)


context2answer = _load_genre_model(
    "/checkpoint/ndecao/2020-11-03/nq_context2answer.bart_large.ls0.1.mt2048.uf4.mu20000.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.beta9999.eps1e-08.clip0.1.lr3e-05.warm500.fp16.ngpu8"
)
answer_context2query = _load_genre_model(
    "/checkpoint/ndecao/2020-11-03/nq_answer_context2query.bart_large.ls0.1.mt2048.uf4.mu20000.dr0.1.atdr0.1.actdr0.0.wd0.01.adam.beta9999.eps1e-08.clip0.1.lr3e-05.warm500.fp16.ngpu8"
)

In [None]:
# Choose the KILT chunk (`rank`) and which half of its batches to process.
rank = 0
first_half = False
second_half = True
batch_size = 32
data = []  # generated [question, answer, passage] rows are appended here

kilt_path = "/checkpoint/fabiopetroni/GENRE/checkpoint/GeNeRe/data/kilt/kilt_{}.jsonl".format(rank)
with jsonlines.open(kilt_path) as reader:
    inputs = list(tqdm(batch_iter(reader, batch_size)))

midpoint = len(inputs) // 2
if first_half:
    inputs = inputs[:midpoint]
elif second_half:
    inputs = inputs[midpoint:]

In [None]:
# Generate synthetic [question, answer, passage] rows:
#   1. sample candidate answers for each passage with context2answer;
#   2. generate one question per "answer >> passage" prompt with
#      answer_context2query, keeping only its first sample.
# Passages from boilerplate sections are skipped.
# NOTE(review): rows accumulate in `data` only; nothing here writes them to disk.
SKIPPED_SECTIONS = ("see also", "references", "external link", "further reading", "notes")

iter_ = tqdm(inputs, smoothing=0)
for batch in iter_:
    psgs = [
        psg["text"]
        for psg in batch
        if not any(section in psg["section"].lower() for section in SKIPPED_SECTIONS)
    ]
    if psgs:
        outputs_context2answer = context2answer.sample(psgs)
        # Keep (answer, passage) as tuples. Re-splitting the prompt string on
        # " >> " produced malformed rows whenever the answer or passage itself
        # contained " >> ".
        answer_psg_pairs = [
            (answer["text"], psg)
            for answers, psg in zip(outputs_context2answer, psgs)
            for answer in answers
        ]
        for pairs in batch_iter(answer_psg_pairs, batch_size):
            prompts = ["{} >> {}".format(answer, psg) for answer, psg in pairs]
            outputs_context2query = [e[0]["text"] for e in answer_context2query.sample(prompts)]
            for q, (answer, psg) in zip(outputs_context2query, pairs):
                data.append([q, answer, psg])
    iter_.set_postfix(data=len(data))

In [3]:



# Reload generated [question, answer, passage] rows saved by an earlier run
# (presumably rank 0 / first_half=True / second_half=False, per the file name).
with open("qac_0_True_False.pkl", "rb") as qac_file:
    results = pickle.load(qac_file)

In [4]:
# NOTE(review): earlier in this notebook `results` was the href -> title/id map;
# here it holds the [question, answer, passage] rows loaded above. Displaying the
# whole list bloats the saved output; consider `results[:5]`.
results

[['what is the first letter of the english alphabet',
  'A',
  'A (named , plural "As", "A\'s", "a"s, "a\'s" or "aes") is the first letter and the first vowel of the modern English alphabet and the ISO basic Latin alphabet. It is similar to the Ancient Greek letter alpha, from which it derives. The uppercase version consists of the two slanting sides of a triangle, crossed in the middle by a horizontal bar.'],
 ['what is the first letter of the english alphabet',
  'first vowel',
  'A (named , plural "As", "A\'s", "a"s, "a\'s" or "aes") is the first letter and the first vowel of the modern English alphabet and the ISO basic Latin alphabet. It is similar to the Ancient Greek letter alpha, from which it derives. The uppercase version consists of the two slanting sides of a triangle, crossed in the middle by a horizontal bar.'],
 ['what is the first letter of the english alphabet',
  'the first vowel',
  'A (named , plural "As", "A\'s", "a"s, "a\'s" or "aes") is the first letter and the f

In [10]:
# Probe the question generator on a hand-written "answer >> passage" prompt.
animalia_prompt = "Animalia >> Animalia is an alliterative alphabet book and contains twenty-six illustrations, one for each letter of the alphabet. Each illustration features an animal from the animal kingdom (A is for alligator, B is for butterfly, etc.) along with a short poem utilizing the letter of the page for many of the words. The illustrations contain many other objects beginning with that letter that the reader can try to identify."
answer_context2query.sample([animalia_prompt])

[[{'text': 'what is the book for each letter of the alphabet',
   'logprob': tensor(-0.5468, device='cuda:1')},
  {'text': 'what is the name of the alphabet book',
   'logprob': tensor(-0.7618, device='cuda:1')},
  {'text': 'what is the name of the book with a letter on it',
   'logprob': tensor(-0.8366, device='cuda:1')},
  {'text': 'what is the name of the book with all the animals',
   'logprob': tensor(-0.8450, device='cuda:1')},
  {'text': 'what is the name of the book with all the letters',
   'logprob': tensor(-0.8483, device='cuda:1')}]]

In [11]:
# Probe the answer sampler on the same passage (without the "answer >> " prefix).
animalia_passage = "Animalia is an alliterative alphabet book and contains twenty-six illustrations, one for each letter of the alphabet. Each illustration features an animal from the animal kingdom (A is for alligator, B is for butterfly, etc.) along with a short poem utilizing the letter of the page for many of the words. The illustrations contain many other objects beginning with that letter that the reader can try to identify."
context2answer.sample([animalia_passage])

[[{'text': 'Animalia', 'logprob': tensor(-0.7034, device='cuda:1')},
  {'text': 'Twenty - six', 'logprob': tensor(-0.9867, device='cuda:1')},
  {'text': 'animalia', 'logprob': tensor(-1.0954, device='cuda:1')},
  {'text': 'Alliterative', 'logprob': tensor(-1.3069, device='cuda:1')},
  {'text': 'Twenty -six', 'logprob': tensor(-1.3373, device='cuda:1')}]]