# Pipeline 1

In [62]:
from pipeline import SerialAnnotator, T5Converter
import time
from pprint import pprint

In [63]:
annotator = SerialAnnotator()
converter = T5Converter()

In [64]:
def pipe_batch(linked, wikisparqls):
  batched = []
  s = time.time()
  annotated = annotator.batch_annotate(linked)
  e = time.time()
  assert len(linked) == len(annotated)
  annotator_time = e - s
  converter_time = 0
  for i, single_annotated in enumerate(annotated):
    # annotated = annotator.annotate(**single_linked)
    # try:
    #   annotator_time += e - s
    # except Exception as e:
    #   print("[Annotator Error]:", e)
    #   continue
    s = time.time()
    try:
      converted = converter.preprocess(**single_annotated, wikisparql=wikisparqls[i])
    except Exception as e:
      print("[Converter Error]:", e)
      continue
      
    e = time.time()
    converter_time += e - s
    batched.append([linked[i], annotated[i], converted])
  print("Anno batch time:", annotator_time)
  print("Conv batch time:", converter_time)
  return batched

## Data

In [65]:
from pathlib import Path
from pprint import pprint
import pandas as pd
import json

# data_path = Path("..") / "t5-for-sparql" / "data" / "lcquad2" / "train.json"
data_path = Path("..") / "t5-for-sparql" / "falcon_links" / "top1ents" / "link_24066.json"

In [66]:
with open(data_path) as f:
  data_json = json.load(f)
sample = data_json[1]

In [67]:
import re

def retrieve_gold_links(masked_wikisparql):
  fragments = masked_wikisparql.upper().strip().split(' ')
  ents = []
  rels = []
  result_dict = {
    'ents': ents,
    'rels': rels,
  }
  for fragment in fragments:
    id_match = re.search('[QP][1-9]+',fragment)
    if id_match:
      id_raw = id_match.group(0)
      if id_raw[0] == 'Q':
        prefix = 'wdt:'
        uri = 'http://www.wikidata.org/entity/' + id_raw
        target_arr = ents
      elif id_raw[0] == 'P':
        prefix = 'wd:'
        uri =  'http://www.wikidata.org/prop/direct/' + id_raw
        target_arr = rels
      else:
        raise ValueError('ID does not start with P or Q')
      match_dict = {
        'id' : id_raw,
        'prefix' : prefix,
        'uri' : uri,
      }
      target_arr.append(match_dict)
  return result_dict

In [68]:
for entry in data_json:
  old_linked, _, old_converted = entry
  wikisparql = old_converted['labels']
  gold_links = retrieve_gold_links(wikisparql)
  entry.append(gold_links)

In [69]:
data_json[0]

[{'utterance': 'What is Delta Air Lines periodical literature mouthpiece?',
  'ents': [{'uri': 'http://www.wikidata.org/entity/Q188920',
    'prefix': 'wd:',
    'id': 'Q188920'}],
  'rels': []},
 {'utterance': 'What is Delta Air Lines periodical literature mouthpiece?',
  'fragments': ['[DEF]', 'wd:', 'Q188920 Delta']},
 {'inputs': 'What is Delta Air Lines periodical literature mouthpiece? <extra_id_59> <extra_id_53> Q188920 Delta',
  'labels': '<extra_id_6> <extra_id_21> <extra_id_39> <extra_id_19> <extra_id_33> <extra_id_53> q188920 <extra_id_54> p2813 <extra_id_39> <extra_id_38> <extra_id_39> <extra_id_54> p31 <extra_id_53> q1002697 <extra_id_15>'},
 {'ents': [{'id': 'Q18892',
    'prefix': 'wdt:',
    'uri': 'http://www.wikidata.org/entity/Q18892'},
   {'id': 'Q1', 'prefix': 'wdt:', 'uri': 'http://www.wikidata.org/entity/Q1'}],
  'rels': [{'id': 'P2813',
    'prefix': 'wd:',
    'uri': 'http://www.wikidata.org/prop/direct/P2813'},
   {'id': 'P31',
    'prefix': 'wd:',
    'uri': '

In [70]:
responses = []
batch_new_linked = []
batch_ans = []
last = 0
cutoff = 0
total_len = len(data_json)
for i, data in enumerate(data_json):
    linked, annotated, converted, gold_linked = data
    if i < cutoff:
        continue
    question = linked['utterance']
    new_ents = linked['ents']
    old_ents_set = set(x['id'] for x in linked['ents'])
    for gold_ent in gold_linked['ents']:
        if gold_ent['id'] not in old_ents_set:
            new_ents.append(gold_ent)
    new_linked = {
        'utterance': question,
        'ents': new_ents,
        'rels':[],
    }
    answer = converted['labels']
    batch_new_linked.append(new_linked)
    batch_ans.append(answer)
    if ((i + 1) % 50) == 0 or i == total_len - 1:
        print("[Pipeline2]:", f"Processing {last}-{i}")
        try:
            for linked, annotated, converted in pipe_batch(batch_new_linked, batch_ans):
                responses.append([linked, annotated, converted])
            batch_new_linked = []
            batch_ans = []
            last = i
        except Exception as e:
            raise e
            # print(e)
            # continue
        with open(f"../t5-for-sparql/falcon_links/5ents-gold_0rels/link_{i}.json", "w") as f:
            json.dump(responses, f, indent=2, separators=(',',':'))

[Pipeline2]: Processing 0-49
