In [8]:
import json
import os
import pickle
import sys
from os.path import join

import requests

# utils lives one directory up, so the path must be extended before importing it.
sys.path.append('..')
from utils import *

# Imported after the star import (as before) so these names take precedence over utils.
from collections import defaultdict
from more_itertools import unique_everseen

In [9]:
# Root of the mannheim-nel data directory. Set the MANNHEIM_NEL_DATA environment
# variable to point elsewhere instead of editing this machine-specific default.
DATA_PATH = os.environ.get('MANNHEIM_NEL_DATA',
                           '/home/rohitalyosha/Student_Job/mannheim-nel/data')

In [11]:
# Lookup tables for candidate generation, all stored under DATA_PATH/dicts.
# Paths are derived from DATA_PATH so relocating the data needs only one change.
DICTS_DIR = join(DATA_PATH, 'dicts')
dis_dict = json_load(join(DICTS_DIR, 'disamb.json'))        # read by Mention.add_dismb_cands
necounts = json_load(join(DICTS_DIR, 'str_necounts.json'))  # looked up by normalised forms in Mention.gen_cands
redirects = json_load(join(DICTS_DIR, 'redirects.json'))    # title -> target title, passed to Doc as rd
ent2id = json_load(join(DICTS_DIR, 'ent_dict.json'))
id2ent = reverse_dict(ent2id)

In [12]:
class Doc:
    """A document plus its mentions, grouped into heuristic coreference clusters.

    Construction wraps every mention tuple in a ``Mention`` and immediately runs
    ``assign_clusters`` so each mention's ``cluster_mention`` is populated.
    """

    def __init__(self, text, mention_tups, coref=True, disamb=False, necounts=None, rd=None):
        """
        Args:
            text: raw document text.
            mention_tups: iterable of ``(doc_id, (mention_text, ent_str, span, small_context))``
                pairs; the first element of each pair is ignored.
            coref, disamb, necounts, rd: forwarded unchanged to every ``Mention``.
        """
        self.text = text
        # The loop variable is named mention_text so it is not confused with the
        # document text stored above.
        self.mentions = [Mention(mention_text,
                                 ent_str,
                                 span,
                                 small_context,
                                 coref=coref,
                                 disamb=disamb,
                                 necounts=necounts,
                                 rd=rd)
                         for _, (mention_text, ent_str, span, small_context) in mention_tups]
        self.assign_clusters()

    @staticmethod
    def _is_acronym_of(acronym, phrase):
        """True if ``acronym`` equals the upper-cased initials of ``phrase``.

        Both the initials of all words and the initials of only the capitalised
        words are tried.
        """
        words = [w for w in phrase.split(' ') if w]
        all_initials = ''.join(w[0] for w in words).upper()
        capital_initials = ''.join(w[0] for w in words if w[0].isupper()).upper()
        return acronym == all_initials or acronym == capital_initials

    @staticmethod
    def _is_affix_match(text, other):
        """Case-insensitive prefix/suffix match in either direction.

        Empty strings never match. Every string starts with '', so without this
        guard an empty mention would absorb every other mention.
        """
        a, b = text.lower(), other.lower()
        if not a or not b:
            return False
        return a.startswith(b) or a.endswith(b) or b.startswith(a) or b.endswith(a)

    def assign_clusters(self):
        """Group mentions into chains and set ``cluster_mention`` on every mention.

        Mentions are processed from the largest ``begin`` to the smallest. Each one
        starts a chain that absorbs still-unchained earlier mentions:

        * A non-empty, all-upper-case mention is treated as a possible acronym. Any
          antecedent whose initials spell it joins the chain. If one does, only
          exact repeats of the acronym are absorbed afterwards.
        * Otherwise, antecedents that are a case-insensitive prefix/suffix of the
          chain's longest mention (or vice versa) are absorbed. This repeats until
          the longest mention stops changing.

        Every chain member gets the chain's longest text as ``cluster_mention``.
        """
        # Rightmost mention first; pop(0) always takes the latest remaining mention.
        unchained_mentions = sorted(self.mentions, key=lambda m: m.begin, reverse=True)

        while unchained_mentions:
            mention = unchained_mentions.pop(0)
            chain = [mention]

            likely_acronym = False
            if mention.text and mention.text.upper() == mention.text:
                # Iterate over a snapshot because matches are removed from the list.
                for m in list(unchained_mentions):
                    if self._is_acronym_of(mention.text, m.text):
                        chain.insert(0, m)
                        unchained_mentions.remove(m)
                        likely_acronym = True

            last = None
            longest_mention = mention
            while last is not longest_mention and unchained_mentions:
                for m in list(unchained_mentions):
                    if likely_acronym:
                        matched = m.text == mention.text
                    else:
                        matched = self._is_affix_match(longest_mention.text, m.text)
                    if matched:
                        chain.insert(0, m)
                        unchained_mentions.remove(m)
                last = longest_mention
                # max() returns the first longest element, exactly like the stable
                # sorted(..., reverse=True)[0] it replaces.
                longest_mention = max(chain, key=lambda m: len(m.text))

            for member in chain:
                member.cluster_mention = longest_mention.text

    def gen_cands(self):
        """Generate candidates for every mention of this document, in place."""
        for mention in self.mentions:
            mention.gen_cands()

In [50]:
class Mention:
    """One entity mention in a document and the candidate entities generated for it."""

    def __init__(self, text, ent_str, span,
                 small_context=None,
                 cluster_mention=None,
                 coref=True,
                 disamb=False,
                 necounts=None,
                 rd=None,
                 dis_dict=None):
        """
        Args:
            text: surface text of the mention.
            ent_str: target entity title; evaluation checks whether it is in ``cands``.
            span: ``(begin, end)`` pair.
            small_context: optional local context, stored as-is.
            cluster_mention: text of the longest mention in this mention's chain
                (normally filled in later by ``Doc.assign_clusters``).
            coref: also use ``cluster_mention`` when generating candidates.
            disamb: also add title / redirect / disambiguation-page candidates.
            necounts: mapping looked up with each normalised form; values are lists
                of candidate entities.
            rd: redirect mapping (title -> target title); None means no redirects.
            dis_dict: disambiguation mapping. When None, the notebook-level
                ``dis_dict`` is used, so existing callers keep working.
        """
        self.text = text
        self.ent_str = ent_str
        self.begin, self.end = span
        self.small_context = small_context
        self.cluster_mention = cluster_mention
        self.coref = coref
        self.disamb = disamb
        self.necounts = necounts
        self.rd = rd
        self.dis_dict = dis_dict
        self.cands = []

    def add_dismb_cands(self, mention):
        """Return title-based candidates for the surface string ``mention``.

        ``mention`` is title-cased and its spaces are replaced with underscores. The
        result lists that title, then its redirect target (if different), then the
        entries stored under the title and under ``<title>_(disambiguation)``.
        """
        redirect_map = self.rd if self.rd is not None else {}
        disamb_map = getattr(self, 'dis_dict', None)
        if disamb_map is None:
            # Fall back to the table loaded in the notebook's data cell.
            disamb_map = dis_dict

        title = mention.title().replace(' ', '_')
        res = [title]
        target = redirect_map.get(title, title)
        if target != title:
            res.append(target)
        for key in (title, title + '_(disambiguation)'):
            if key in disamb_map:
                res.extend(disamb_map[key])
        return res

    def gen_cands(self, max_cands=100):
        """Rebuild ``self.cands`` from scratch for this mention.

        Candidates come from ``necounts`` for every normalised form of the mention
        text, plus the cluster mention when ``coref`` is set and it is known.
        With ``disamb`` set, title/disambiguation candidates follow. Duplicates are
        dropped (first occurrence kept). The list is then passed to
        ``equalize_len(..., max_cands, pad='')``, which presumably pads with '' or
        truncates to ``max_cands``.

        Args:
            max_cands: target list length. The default of 100 matches the previous
                hard-coded value; evaluating more than this many candidates has no effect.
        """
        forms = list(get_normalised_forms(self.text))
        if self.coref and self.cluster_mention is not None:
            forms.extend(get_normalised_forms(self.cluster_mention))
        # Ordered de-duplication. Iterating a set here would make candidate order,
        # and therefore which candidates survive truncation, depend on hash randomisation.
        forms = list(dict.fromkeys(forms))

        lookup = self.necounts if self.necounts is not None else {}
        cands = []
        for nf in forms:
            cands.extend(lookup.get(nf, []))
        if self.disamb:
            for nf in forms:
                cands.extend(self.add_dismb_cands(nf))

        # Start from a fresh list so repeated calls do not accumulate candidates.
        self.cands = equalize_len(list(unique_everseen(cands)), max_cands, pad='')

## Candidate Generation

In [51]:
split = "train"  # CoNLL partition: selects the conll-<split>.pickle file and the raw-text entry

In [52]:
# Pre-processed CoNLL examples for the chosen split, and raw document texts keyed by split.
# NOTE(review): pickle can execute code on load; only use trusted, locally produced files.
conll_train = pickle_load(join(DATA_PATH, 'Conll', 'conll-{}.pickle'.format(split)))
conll_raw = pickle_load(join(DATA_PATH, 'Conll', 'conll_raw_text.pickle'))
id2c = conll_raw[split]    # indexed by doc_id when the Doc objects are built
_, examples = conll_train  # second element: the (doc_id, mention tuple) examples

In [53]:
# Bucket the examples by document. Each example is kept whole because Doc unpacks
# it as (doc_id, (text, ent_str, span, small_context)).
docid2mention_tups = defaultdict(list)
for example in examples:
    doc_id, mention_tup = example
    docid2mention_tups[doc_id].extend([example])

In [58]:
def run_cand_eval(docs, num_cands=100):
    """Print ``covered total`` and return the candidate recall over all mentions.

    A mention counts as covered when its ``ent_str`` appears among the first
    ``num_cands`` entries of its ``cands`` list.

    Args:
        docs: iterable of objects with a ``mentions`` attribute. Each mention needs
            ``ent_str`` and ``cands`` (populated by ``gen_cands``).
        num_cands: how many leading candidates to consider.

    Returns:
        covered / total as a float, or 0.0 when there are no mentions at all
        (instead of raising ZeroDivisionError).
    """
    covered = 0
    total = 0
    for doc in docs:
        for mention in doc.mentions:
            total += 1
            # Compared against the raw entity title; redirect normalisation
            # (redirects.get(ent_str, ent_str)) is currently not applied.
            if mention.ent_str in mention.cands[:num_cands]:
                covered += 1
    print(covered, total)
    return covered / total if total else 0.0

In [59]:
# Candidate lists do not depend on the evaluation cut-off. Each (coref, disamb)
# setting therefore builds docs and generates candidates once; the sweep only
# slices the existing lists.
NUM_CANDS_SWEEP = [50, 100, 128, 200, 256, 500]

for coref in [False, True]:
    for disamb in [True, False]:
        docs = [Doc(id2c[doc_id],
                    mention_tups,
                    coref=coref,
                    disamb=disamb,
                    necounts=necounts,
                    rd=redirects)
                for doc_id, mention_tups in docid2mention_tups.items()]
        for doc in docs:
            doc.gen_cands()

        # Mention.gen_cands fixes the list length via equalize_len. Cut-offs beyond
        # that length cannot change the result, so flag them rather than report
        # them as if they were meaningful.
        cand_list_len = max((len(m.cands) for d in docs for m in d.mentions), default=0)

        for num_cands in NUM_CANDS_SWEEP:
            result = run_cand_eval(docs, num_cands=num_cands)

            print_str = f'Coref: {coref}, disamb: {disamb}, '
            print_str += f'num_cands: {num_cands}, result: {result}'
            if num_cands > cand_list_len:
                print_str += f' (candidate lists hold only {cand_list_len}; cut-off has no effect)'
            print(print_str)

16819 18546
Coref: False, disamb: True, num_cands: 50, result: 0.9068801897983393
17049 18546
Coref: False, disamb: True, num_cands: 100, result: 0.9192817858298286
17049 18546
Coref: False, disamb: True, num_cands: 128, result: 0.9192817858298286
17049 18546
Coref: False, disamb: True, num_cands: 200, result: 0.9192817858298286
17049 18546
Coref: False, disamb: True, num_cands: 256, result: 0.9192817858298286
17049 18546
Coref: False, disamb: True, num_cands: 500, result: 0.9192817858298286
16060 18546
Coref: False, disamb: False, num_cands: 50, result: 0.8659549228944247
16243 18546
Coref: False, disamb: False, num_cands: 100, result: 0.8758222797368704
16243 18546
Coref: False, disamb: False, num_cands: 128, result: 0.8758222797368704
16243 18546
Coref: False, disamb: False, num_cands: 200, result: 0.8758222797368704
16243 18546
Coref: False, disamb: False, num_cands: 256, result: 0.8758222797368704
16243 18546
Coref: False, disamb: False, num_cands: 500, result: 0.8758222797368704
