# Calculate LEA coreference evaluation

In [63]:
# Load ground-truth annotated entity mentions
import os
import pandas as pd

annotations_dirpath = '/data/fanfiction_ao3/annotated_10fandom/dev/entity_clusters'

gold_entities = {} # fic_id: {cluster_name: {(chapter_id, paragraph_id, token_id_start, token_id_end), ...}}

for fname in os.listdir(annotations_dirpath):
    
    fic_id = int(fname.split('_')[1])
    gold_entities[fic_id] = {}
    
    df = pd.read_csv(os.path.join(annotations_dirpath, fname))
    for colname in df.columns:
        gold_entities[fic_id][colname] = set()
        for mention in df[colname].dropna():
            parts = mention.split('.')
            chapter_id = int(parts[0])
            paragraph_id = int(parts[1])
            if '-' in parts[2]:
                token_id_start = int(parts[2].split('-')[0])
                token_id_end = int(parts[2].split('-')[-1])
            else:
                token_id_start = int(parts[2])
                token_id_end = int(parts[2])
                
            gold_entities[fic_id][colname].add((chapter_id, paragraph_id, token_id_start, token_id_end))

gold_entities

{1621415: {'1 (Skye)': {(1, 5, 2, 2),
   (1, 5, 11, 11),
   (1, 5, 17, 17),
   (1, 5, 25, 25),
   (1, 6, 13, 13),
   (1, 6, 31, 31),
   (1, 7, 4, 4),
   (1, 7, 11, 11),
   (1, 7, 19, 19),
   (1, 7, 33, 33),
   (1, 8, 13, 13),
   (1, 9, 6, 6),
   (1, 9, 8, 8),
   (1, 9, 26, 26),
   (1, 9, 30, 30),
   (1, 10, 22, 22),
   (1, 10, 53, 53),
   (1, 10, 57, 57),
   (1, 10, 59, 59),
   (1, 10, 89, 89),
   (1, 11, 13, 13),
   (1, 11, 16, 16),
   (1, 11, 24, 24),
   (1, 13, 1, 1),
   (1, 14, 7, 7),
   (1, 14, 11, 11),
   (1, 14, 14, 14),
   (1, 14, 20, 20),
   (1, 14, 25, 25),
   (1, 16, 6, 6),
   (1, 16, 51, 51),
   (1, 16, 72, 72),
   (1, 16, 74, 74),
   (1, 17, 25, 25),
   (1, 17, 29, 29),
   (1, 17, 33, 33),
   (1, 18, 7, 7),
   (1, 18, 11, 11),
   (1, 18, 20, 20),
   (1, 18, 30, 30),
   (1, 19, 12, 12),
   (1, 19, 16, 16),
   (1, 20, 2, 2),
   (1, 20, 7, 7),
   (1, 20, 26, 26),
   (1, 20, 36, 36),
   (1, 20, 40, 40),
   (1, 20, 57, 57),
   (1, 20, 73, 73),
   (1, 20, 82, 82),
   (1, 20, 88,

In [64]:
def extract_entity_mentions(text):
    """ Return token start and endpoints of entity mentions embedded in text. """
    
    token_count = 1
    entities = {} # cluster_name: {(token_id_start, token_id_end), ...}
    
    tokens = text.split(' ')
    for i, token in enumerate(tokens):
        if token.startswith('($_'): # entity cluster name
            if not token in entities:
                entities[token] = set()
                
            mention = tokens[i-1]
            mention_len = len(mention.split('_'))
            token_id_start = token_count - 1
            token_id_end = (token_count - 1) + (mention_len - 1)
            
            token_count += mention_len - 1 # for the underscore-connected mentions
                
            entities[token].add((token_id_start, token_id_end))
            
        else:
            # Advance token count
            token_count += 1
            
    return entities

In [65]:
# Load entity cluster predictions
import os
import pandas as pd

predictions_dirpath = '/data/fanfiction_ao3/annotated_10fandom/dev/pipeline_output/char_coref_stories'

predicted_entities = {} # fic_id: {cluster_name: {(chapter_id, paragraph_id, token_id_start, token_id_end), ...}}

for fname in sorted(os.listdir(predictions_dirpath))[:1]:
    print(fname)
    df = pd.read_csv(os.path.join(predictions_dirpath, fname))
    for row in list(df.itertuples()):
        fic_id = row.fic_id
        chapter_id = row.chapter_id
        para_id = row.para_id
        entities = extract_entity_mentions(row.text_tokenized)
#         print(entities)
#         print(row.text_tokenized)
        
        if not fic_id in predicted_entities:
            predicted_entities[fic_id] = {}
        
        for cluster_name in entities:
            if not cluster_name in predicted_entities[fic_id]:
                predicted_entities[fic_id][cluster_name] = set()
            
            for mention in entities[cluster_name]:
                token_id_start = mention[0]
                token_id_end = mention[1]
                predicted_entities[fic_id][cluster_name].add((chapter_id, para_id, token_id_start, token_id_end))
                
predicted_entities

allmarvel_1621415.coref.csv


{1621415: {'($_HYDRA)': {(1, 2, 6, 6), (1, 13, 20, 20)},
  '($_Skye)': {(1, 5, 2, 2),
   (1, 6, 29, 29),
   (1, 9, 6, 6),
   (1, 9, 30, 30),
   (1, 10, 53, 53),
   (1, 11, 16, 16),
   (1, 13, 1, 1),
   (1, 14, 7, 7),
   (1, 16, 52, 52),
   (1, 17, 29, 29),
   (1, 18, 2, 2),
   (1, 19, 12, 12),
   (1, 19, 33, 33),
   (1, 20, 53, 53),
   (1, 20, 132, 132),
   (1, 21, 11, 11),
   (1, 22, 65, 65),
   (1, 24, 8, 8),
   (1, 30, 5, 5),
   (1, 30, 15, 15),
   (1, 31, 1, 1),
   (1, 31, 18, 18),
   (1, 31, 23, 23),
   (1, 32, 1, 1),
   (1, 32, 25, 25),
   (1, 32, 31, 31),
   (1, 32, 36, 36),
   (1, 32, 86, 86),
   (1, 32, 92, 92),
   (1, 32, 96, 96),
   (1, 32, 101, 101),
   (1, 32, 111, 111),
   (1, 33, 21, 21),
   (1, 33, 34, 34),
   (1, 33, 45, 45),
   (1, 33, 59, 59),
   (1, 33, 65, 65),
   (1, 34, 10, 10),
   (1, 34, 25, 25),
   (1, 34, 33, 33),
   (1, 34, 36, 36),
   (1, 34, 51, 51),
   (1, 34, 64, 64),
   (1, 35, 24, 24),
   (1, 35, 43, 43),
   (1, 35, 53, 53),
   (1, 36, 14, 14),
   (1, 

In [66]:
import itertools

def links(entity_mentions):
    """ Returns all the links in an entity between mentions """
    
    if len(entity_mentions) == 1: # self-link
        links = {list(entity_mentions)[0], list(entity_mentions)[0]}

    else:
        links = set(itertools.combinations(entity_mentions, 2))
        
    return links

In [77]:
import numpy as np
from IPython.core.debugger import set_trace

def lea_recall(predicted_entities, gold_entities):
    
    fic_recalls = {}
    
    for fic_id in gold_entities:
        
        cluster_resolutions = {}
        cluster_sizes = {}
        
        for gold_cluster, gold_mentions in gold_entities[fic_id].items():
            gold_links = links(gold_mentions)
            
            cluster_resolution = 0
            
            for predicted_cluster, predicted_mentions in predicted_entities[fic_id].items():
                predicted_links = links(predicted_mentions)
                
                cluster_resolution += len(predicted_links.intersection(gold_links))
                
            cluster_resolution = cluster_resolution/len(gold_links)
            cluster_resolutions[gold_cluster] = cluster_resolution
            cluster_sizes[gold_cluster] = len(gold_mentions)
            
        # take importance (size) of clusters into account
#         print(cluster_resolutions)
        fic_recalls[fic_id] = sum([cluster_sizes[c] * cluster_resolutions[c] for c in gold_entities[fic_id]])/sum(cluster_sizes.values())
        
    # Total recall as mean across fics
#     print(fic_recalls)
    total_recall = np.mean(list(fic_recalls.values()))
    return total_recall

In [78]:
import numpy as np
from IPython.core.debugger import set_trace

def lea_precision(predicted_entities, gold_entities):
    
    fic_precisions = {}
    
    for fic_id in gold_entities:
        
        cluster_resolutions = {}
        cluster_sizes = {}
        
        for predicted_cluster, predicted_mentions in predicted_entities[fic_id].items():
            predicted_links = links(predicted_mentions)
            
            cluster_resolution = 0
            
            for gold_cluster, gold_mentions in gold_entities[fic_id].items():
                gold_links = links(gold_mentions)
                cluster_resolution += len(predicted_links.intersection(gold_links))
            
            cluster_resolution = cluster_resolution/len(predicted_links)
            cluster_resolutions[predicted_cluster] = cluster_resolution
            cluster_sizes[predicted_cluster] = len(predicted_mentions)
            
        # take importance (size) of clusters into account
#         print(cluster_resolutions)
        fic_precisions[fic_id] = sum([cluster_sizes[c] * cluster_resolutions[c] for c in predicted_entities[fic_id]])/sum(cluster_sizes.values())
        
    # Total precision as mean across fics
#     print(fic_precisions)
    total_precision = np.mean(list(fic_precisions.values()))
    return total_precision

In [69]:
lea_recall(predicted_entities, gold_entities)

{'1 (Skye)': 0.01139211665527455, '5 (Coulson)': 0.042767295597484274, '2 (Jemma)': 0.0, '3 (Fitz)': 0.0, '4 (Trip)': 0.0}
{1621415: 0.024993739043325823}


0.024993739043325823

In [70]:
lea_precision(predicted_entities, gold_entities)

{'($_HYDRA)': 0.0, '($_Skye)': 0.012008281573498964, '($_Jemma)': 0.43859649122807015, '($_That_Coulson)': 0.0640218878248974, '($_SHIELD)': 0.0, '($_Stars)': 0.0}
{1621415: 0.08021746118261987}


0.08021746118261987

In [71]:
gold_entities

{1621415: {'1 (Skye)': {(1, 5, 2, 2),
   (1, 5, 11, 11),
   (1, 5, 17, 17),
   (1, 5, 25, 25),
   (1, 6, 13, 13),
   (1, 6, 31, 31),
   (1, 7, 4, 4),
   (1, 7, 11, 11),
   (1, 7, 19, 19),
   (1, 7, 33, 33),
   (1, 8, 13, 13),
   (1, 9, 6, 6),
   (1, 9, 8, 8),
   (1, 9, 26, 26),
   (1, 9, 30, 30),
   (1, 10, 22, 22),
   (1, 10, 53, 53),
   (1, 10, 57, 57),
   (1, 10, 59, 59),
   (1, 10, 89, 89),
   (1, 11, 13, 13),
   (1, 11, 16, 16),
   (1, 11, 24, 24),
   (1, 13, 1, 1),
   (1, 14, 7, 7),
   (1, 14, 11, 11),
   (1, 14, 14, 14),
   (1, 14, 20, 20),
   (1, 14, 25, 25),
   (1, 16, 6, 6),
   (1, 16, 51, 51),
   (1, 16, 72, 72),
   (1, 16, 74, 74),
   (1, 17, 25, 25),
   (1, 17, 29, 29),
   (1, 17, 33, 33),
   (1, 18, 7, 7),
   (1, 18, 11, 11),
   (1, 18, 20, 20),
   (1, 18, 30, 30),
   (1, 19, 12, 12),
   (1, 19, 16, 16),
   (1, 20, 2, 2),
   (1, 20, 7, 7),
   (1, 20, 26, 26),
   (1, 20, 36, 36),
   (1, 20, 40, 40),
   (1, 20, 57, 57),
   (1, 20, 73, 73),
   (1, 20, 82, 82),
   (1, 20, 88,

In [79]:
# Test calculation with toy examples

import itertools

# set(itertools.combinations({(1,3), (1,4), (2,2), (3,5)}, 2))
test_gold = {1: {'A': {(1,1,1,1), (1,1,2,2), (1,1,3,3)},
                'B': {(1,1,4,4), (1,1,5,5), (1,1,6,6)}
                }}

test_predicted = {1: {'A': {(1,1,1,1), (1,1,2,2), (1,1,3,3), (1,1,6,6)},
                'B': {(1,1,4,4), (1,1,5,5)}
                }}

print(lea_recall(test_predicted, test_gold))
print(lea_precision(test_predicted, test_gold))

0.6666666666666666
0.6666666666666666


# Create personal coref annotation interface (token id subscripts)

In [2]:
def add_token_subscript(text):
    numbered_tokens = [el for el in enumerate(text.split())]
    subscripted = [f'{tok}<sub>{tok_num+1}</sub>' for tok_num, tok in numbered_tokens]
    return ' '.join(subscripted)

In [13]:
import os
import pandas as pd

pd.set_option('display.max_colwidth', -1)
annotation_dirpath = '/data/fanfiction_ao3/annotated_10fandom/'
csv_dirpath = os.path.join(annotation_dirpath, 'fics')
subscripts_dirpath = os.path.join(annotation_dirpath, 'subscripted')
fnames = os.listdir(csv_dirpath)

fandoms = [
    'allmarvel'
]

for fandom in fandoms:
    for fname in fnames:
        if fname.endswith('.csv') and fname.startswith(fandom):
            data = pd.read_csv(os.path.join(csv_dirpath, fname))
            data['annotation_text'] = data['text_tokenized'].map(add_token_subscript)
            data.loc[:, ['chapter_id', 'para_id', 'annotation_text']].to_html(os.path.join(subscripts_dirpath, f'{fname[:-4]}_subscripts.html'), escape=False, index=False)

# Load data for preliminary annotation dataset

In [2]:
import random

all_fandoms = [
#     'allmarvel',
    'supernatural',
    'harrypotter',
    'dcu',
    'sherlock',
    'teenwolf',
    'starwars',
    'drwho',
    'tolkien',
    'dragonage',
]

random.sample(all_fandoms, 4)

['teenwolf', 'harrypotter', 'sherlock', 'tolkien']

In [8]:
import os, shutil
import random

old_seeds = [9, 12, 1234, 99, 120]
current_seed = 120
random.seed(current_seed)

dataset = 'complete_en_1k-50k'
fandoms = [
#     'allmarvel',
#     'supernatural',
#     'harrypotter',
#     'dcu',
    'sherlock',
    'teenwolf',
#     'starwars',
#     'drwho',
#     'tolkien',
#     'dragonage',
]

for fandom in fandoms:

    fic_dirpath = f'/data/fanfiction_ao3/{fandom}/{dataset}/fics'
    annotation_dirpath = f'/data/fanfiction_ao3/annotated_10fandom/dev/fics/'
    fnames = os.listdir(fic_dirpath)
    selected = random.sample(fnames, 1)[0]
    print(f'{fandom}: {selected}')
    shutil.copy(os.path.join(fic_dirpath, selected), os.path.join(annotation_dirpath, f'{fandom}_{selected}'))

sherlock: 12828381.csv
teenwolf: 1145590.csv
