In [1]:
import os
import os.path as osp
import pandas as pd
import numpy as np
import snorkel
import dotenv
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel import SnorkelSession

dotenv.load_dotenv('../env.sh')
%run ../src/supervision.py
corpus_dir = osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_00')
corpus_docs_dir = osp.join(corpus_dir, 'links')
collection_dir = osp.join(os.environ['REPO_DATA_DIR'], 'brat', 'collection_02')
session = SnorkelSession()

In [2]:
# Clear existing candidates, if need be

from snorkel.models import Candidate
ct = session.query(Candidate).count()
# if ct > 0:
#     # Clear all existing candidates (don't let extractors do it)
#     # See: https://github.com/HazyResearch/snorkel/blob/master/snorkel/candidates.py#L47
#     # *This seems to always cause a database lock somehow -- perhaps it needs to be done with autocommit but
#     # for now a workaround is to do this at the beginning and the restart the kernel
#     from snorkel.models import Candidate
#     ndelete = session.query(Candidate).delete()
#     session.commit()
#     # Restart kernel
ct

0

In [3]:
tags = pd.read_csv(osp.join(corpus_dir, 'tags.csv'))
tags = tags[tags['type'].isin(ENT_TYPES)]
print(tags['type'].value_counts())
tags.head()

CYTOKINE                22425
IMMUNE_CELL_TYPE        20807
TRANSCRIPTION_FACTOR    16721
Name: type, dtype: int64


Unnamed: 0,id,type,ent_id,ent_lbl,ent_prefid,start_chr,end_chr,start_wrd,end_wrd,text
1,PMC5743442,IMMUNE_CELL_TYPE,CTBFBDE5121B6748D1,Th17,CTBFBDE5121B6748D1,65,69,8,9,Th17
2,PMC5743442,IMMUNE_CELL_TYPE,CTBFBDE5121B6748D1,Th17,CTBFBDE5121B6748D1,87,91,11,12,Th17
3,PMC5743442,CYTOKINE,CKFD4CA0B2B4BC3AE4,TGF-β,CKFD4CA0B2B4BC3AE4,174,179,25,26,TGF-β
4,PMC5743442,IMMUNE_CELL_TYPE,CTBFBDE5121B6748D1,Th17,CTBFBDE5121B6748D1,199,203,29,30,Th17
6,PMC5743442,CYTOKINE,CKC59F8990F767EBD4,TGF-β,CKFD4CA0B2B4BC3AE4,275,280,42,43,TGF-β


In [4]:
#doc_preprocessor = TextDocPreprocessor(corpus_docs_dir, max_docs=10)
doc_preprocessor = TextDocPreprocessor(corpus_docs_dir)

In [5]:
def get_annotated_doc_ids():
    import glob
    return [osp.splitext(osp.basename(f))[0] for f in glob.glob(osp.join(collection_dir, '*.txt'))]
annotated_ids = get_annotated_doc_ids()
# Show frequency of docs that are annotated AND have tags of some kind
pd.Series({did:did in tags['id'].values for did in annotated_ids}).value_counts()

True     62
False     3
dtype: int64

In [6]:
from string import punctuation

def offsets_to_token(left, right, offset_array, lemmas, punc=set(punctuation)):
    token_start, token_end = None, None
    for i, c in enumerate(offset_array):
        if left >= c:
            token_start = i
        if c > right and token_end is None:
            token_end = i
            break
    token_end = len(offset_array) - 1 if token_end is None else token_end
    token_end = token_end - 1 if lemmas[token_end - 1] in punc else token_end
    return range(token_start, token_end)


class EntityTagger(object):

    def __init__(self, tags):   
        self.tags = tags.set_index('id')
        self.reset_stats()

    def reset_stats(self):
        self.stats = {'docs': set(), 'found': set()}
        return self
    
    def get_stats(self):
        return dict(
            n_tags=len(self.tags), 
            n_docs=len(self.stats['docs']),
            n_docs_found=len(self.stats['found']),
            pct_docs_found=100*len(self.stats['found'])/len(self.stats['docs'])
        )
    
    def tag(self, parts):
        """Tag tokens in a single sentence"""
        # Extract doc id (e.g. PMC123932) and character offsets of sentence
        docid, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        self.stats['docs'].add(docid)
        if docid not in self.tags.index:
            return parts
        self.stats['found'].add(docid)
        tags = self.tags.loc[[docid]]
        sent_start, sent_end = int(sent_start), int(sent_end)
        for r in tags.itertuples():
            tag_start, tag_end = r.start_chr, r.end_chr
            # Determine whether or not the tag is in this sentence
            if not (sent_start <= tag_start <= sent_end):
                continue
            offsets = [offset + sent_start for offset in parts['char_offsets']]
            tkn_idx_rng = offsets_to_token(tag_start, tag_end, offsets, parts['lemmas'])
            for tkn_idx in tkn_idx_rng:
                parts['entity_types'][tkn_idx] = r.type.lower()
                parts['entity_cids'][tkn_idx] = r.ent_id + ':' + r.ent_prefid
        return parts

In [7]:
from snorkel.parser import CorpusParser, Spacy

tagger = EntityTagger(tags)
corpus_parser = CorpusParser(fn=tagger.tag)
corpus_parser.apply(list(doc_preprocessor))

  0%|          | 0/555 [00:00<?, ?it/s]

Clearing existing...
Running UDF...


100%|██████████| 555/555 [06:33<00:00,  1.41it/s]


In [8]:
# This will show how many documents didn't have any tagged
# (or that were otherwise not included in tagging but included
# here in parsing -- which should be rare)
tagger.get_stats()

{'n_tags': 59953,
 'n_docs': 555,
 'n_docs_found': 520,
 'pct_docs_found': 93.69369369369369}

In [9]:
from snorkel.models import Document, Sentence

print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 555
Sentences: 72223


In [10]:
docs = session.query(Document).all()

In [11]:
all_ids = [doc.name for doc in docs]
tagged_ids = list(np.intersect1d(all_ids, tags['id']))
dev_ids = list(set(annotated_ids))
train_ids = [i for i in tagged_ids if i not in dev_ids]
len(train_ids), len(dev_ids)

(458, 65)

In [12]:
train_sents, dev_sents = set(), set()
for i, doc in enumerate(docs):
    if doc.name not in tagged_ids:
        continue
    for s in doc.sentences:
        if doc.name in train_ids:
            train_sents.add(s)
        elif doc.name in dev_ids:
            dev_sents.add(s)
        else:
            raise Exception('ID <{0}> not found in any id set'.format(doc.name))

In [13]:
from snorkel.candidates import PretaggedCandidateExtractor
classes = get_candidate_classes()
candidate_extractors = [
    PretaggedCandidateExtractor(c.subclass, c.entity_types)
    for c in classes.values()
]

In [14]:
for k, sents in enumerate([train_sents, dev_sents]):
    for extractor in candidate_extractors:
        relation_class = extractor.udf_init_kwargs['candidate_class']
        extractor.apply(sents, split=k, clear=False)
        print('Number of candidates generated for split {}, relation type {}: {}'.format(
            k, relation_class.__name__,
            session.query(relation_class).filter(relation_class.split == k).count()
        ))

  0%|          | 141/70540 [00:00<00:50, 1400.73it/s]

Running UDF...


100%|██████████| 70540/70540 [00:41<00:00, 1715.93it/s]
  0%|          | 315/70540 [00:00<00:22, 3149.09it/s]

Number of candidates generated for split 0, relation type InducingCytokine: 11836
Running UDF...


100%|██████████| 70540/70540 [00:26<00:00, 2652.90it/s]
  0%|          | 213/70540 [00:00<00:33, 2100.63it/s]

Number of candidates generated for split 0, relation type SecretedCytokine: 11836
Running UDF...


100%|██████████| 70540/70540 [00:27<00:00, 2564.78it/s]
  4%|▍         | 61/1449 [00:00<00:02, 598.80it/s]

Number of candidates generated for split 0, relation type InducingTranscriptionFactor: 8477
Running UDF...


100%|██████████| 1449/1449 [00:01<00:00, 1169.34it/s]
  4%|▎         | 54/1449 [00:00<00:03, 360.02it/s]

Number of candidates generated for split 1, relation type InducingCytokine: 593
Running UDF...


100%|██████████| 1449/1449 [00:01<00:00, 1342.74it/s]
  8%|▊         | 113/1449 [00:00<00:01, 1109.47it/s]

Number of candidates generated for split 1, relation type SecretedCytokine: 593
Running UDF...


100%|██████████| 1449/1449 [00:01<00:00, 1404.22it/s]

Number of candidates generated for split 1, relation type InducingTranscriptionFactor: 406





## Load Gold Labels

In [15]:
# Read csv export with annotated relations to load:
relations = pd.read_csv(osp.join(corpus_dir, 'relations.csv'))
relations.head()

Unnamed: 0,e1_end_chr,e1_start_chr,e1_text,e1_typ,e2_end_chr,e2_start_chr,e2_text,e2_typ,id,rel_typ
0,24,19,Gfi-1,TRANSCRIPTION_FACTOR,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
1,44,39,TGF-β,CYTOKINE,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Induction
2,44,39,TGF-β,CYTOKINE,125,90,CD103+ inducible regulatory T cells,IMMUNE_CELL_TYPE,PMC2646571,Induction
3,24,19,Gfi-1,TRANSCRIPTION_FACTOR,125,90,CD103+ inducible regulatory T cells,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
4,371,366,Gfi-1,TRANSCRIPTION_FACTOR,436,433,Th2,IMMUNE_CELL_TYPE,PMC2646571,Differentiation


In [16]:
relations['rel_typ'].value_counts()

Induction          129
Secretion          113
Differentiation    101
Name: rel_typ, dtype: int64

In [17]:
relations.groupby(['e1_typ', 'e2_typ', 'rel_typ']).size()

e1_typ                e2_typ            rel_typ        
CYTOKINE              IMMUNE_CELL_TYPE  Induction          129
                                        Secretion          113
TRANSCRIPTION_FACTOR  IMMUNE_CELL_TYPE  Differentiation    101
dtype: int64

In [145]:
# # Reset annotation tables (if loading them below fails)
# session.execute('DELETE FROM stable_label;')
# session.execute('DELETE FROM gold_label;')
# session.execute('DELETE FROM gold_label_key;')
# from snorkel.models import GoldLabel, GoldLabelKey, StableLabel
# session.commit()
# session.query(StableLabel).count(), session.query(GoldLabel).count(), session.query(GoldLabelKey).count()

In [18]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

def get_stable_id(r):
    return '{}::span:{}:{}'.format(r['id'], r['start_chr'], r['end_chr']-1)

def reload_labels(
    session, candidate_class, annotator_name, split, 
    filter_label_split=True, create_missing_cands=False):
    """Reloads stable annotator labels into the AnnotatorLabel table"""
    from snorkel.models import GoldLabel, GoldLabelKey, StableLabel, Context
    from future.utils import iteritems
    # Sets up the AnnotatorLabelKey to use
    ak = session.query(GoldLabelKey).filter(GoldLabelKey.name == annotator_name).first()
    if ak is None:
        ak = GoldLabelKey(name=annotator_name)
        session.add(ak)
        session.commit()

    labels = []
    missed = []
    sl_query = session.query(StableLabel).filter(StableLabel.annotator_name == annotator_name)
    sl_query = sl_query.filter(StableLabel.split == split) if filter_label_split else sl_query
    for sl in sl_query.all():
        context_stable_ids = sl.context_stable_ids.split('~~')

        # Check for labeled Contexts
        # TODO: Does not create the Contexts if they do not yet exist!
        contexts = []
        for stable_id in context_stable_ids:
            context = session.query(Context).filter(Context.stable_id == stable_id).first()
            if context:
                contexts.append(context)
        if len(contexts) < len(context_stable_ids):
            missed.append(sl)
            continue

        # Check for Candidate
        # Assemble candidate arguments
        candidate_args  = {'split' : split}
        for i, arg_name in enumerate(candidate_class.__argnames__):
            candidate_args[arg_name] = contexts[i]

        # Assemble query and check
        candidate_query = session.query(candidate_class)
        for k, v in iteritems(candidate_args):
            candidate_query = candidate_query.filter(getattr(candidate_class, k) == v)
        candidate = candidate_query.first()

        # Optionally construct missing candidates
        if candidate is None and create_missing_cands:
            candidate = candidate_class(**candidate_args)

        # If candidate is none, mark as missed and continue
        if candidate is None:
            missed.append(sl)
            continue

        # Check for AnnotatorLabel, otherwise create
        label = session.query(GoldLabel).filter(GoldLabel.key == ak).filter(GoldLabel.candidate == candidate).first()
        if label is None:
            label = GoldLabel(candidate=candidate, key=ak, value=sl.value)
            session.add(label)
            labels.append(label)

    session.commit()
    print("AnnotatorLabels created: %s, missed: %s" % (len(labels), len(missed)))
    return missed, labels
    
def load_external_labels(session, relations, candidate_class, annotator_name='gold'):
    print(annotator_name)
    for i, r in relations.iterrows():    

        # We check if the label already exists, in case this cell was already executed
        e1_id = get_stable_id(r.filter(regex='^e1_|^id$').rename(lambda v: v.replace('e1_', '')))
        e2_id = get_stable_id(r.filter(regex='^e2_|^id$').rename(lambda v: v.replace('e2_', '')))
        
        context_stable_ids = "~~".join([e1_id, e2_id])
        query = session.query(StableLabel)\
            .filter(StableLabel.context_stable_ids == context_stable_ids)\
            .filter(StableLabel.annotator_name == annotator_name)
        if query.count() == 0:
            session.add(StableLabel(
                context_stable_ids=context_stable_ids,
                annotator_name=annotator_name,
                value=1
            ))
            
    session.commit()
    # This function will create GoldLabel records for each StableLabel above after
    # selecting them based on annotator_name.  The annotator name should be different
    # for each candidate class if they might have identical context_stable_ids since
    # otherwise as written above only the first value for the same annotator + context_stable_ids
    # will be saved.
    # Other notes: split will be used to find candidates necessary to create GoldLabels though
    # it is not necessary for StableLabel filtering in this case (thus filter_label_split=False)
    # because the labels were not created with a split above
    #reload_annotator_labels(
    return reload_labels(
        session, candidate_class, annotator_name, split=1, 
        filter_label_split=False, create_missing_cands=False)
    
cand_summary = {}
for extractor in candidate_extractors:
    relation_class = extractor.udf_init_kwargs['candidate_class']
    label = classes[relation_class.__name__].label
    field = classes[relation_class.__name__].field
    df = relations[relations['rel_typ'] == label]
    assert len(df) > 0, 'Found no records for relation type {}'.format(label)
    print('Found {} relations for type {}'.format(len(df), label))
    cand_summary[relation_class.__name__] = load_external_labels(session, df, relation_class, annotator_name=field)

Found 129 relations for type Induction
inducing_cytokine
AnnotatorLabels created: 112, missed: 17
Found 113 relations for type Secretion
secreted_cytokine
AnnotatorLabels created: 65, missed: 48
Found 101 relations for type Differentiation
inducing_transcription_factor
AnnotatorLabels created: 59, missed: 42


In [20]:
from snorkel.models import Context 

# Show which entity types and relations were unable to be matched with spans
# extracted and inserted into snorkel db
def summarize_missing_candidates(cand_summary):
    df = []
    for c in cand_summary:
        class_name = classes[c].name
        missed = cand_summary[c][0]
        for mc in missed:
            doc_id = mc.context_stable_ids.split('::')[0]
            ids = mc.context_stable_ids.split('~~')
            typs = []
            for i, sid in enumerate(ids):
                ctx = session.query(Context).filter(Context.stable_id == sid).all()
                if len(ctx) == 0:
                    typs.append(classes[c].entity_types[i])
            df.append((class_name, doc_id, ','.join(typs)))
    return pd.DataFrame(df, columns=['relation', 'doc_id', 'missing'])
df_miss = summarize_missing_candidates(cand_summary)
df_miss.groupby(['relation', 'missing']).size()

relation                     missing                              
InducingCytokine             cytokine                                  4
                             cytokine,immune_cell_type                 1
                             immune_cell_type                         12
InducingTranscriptionFactor                                            1
                             immune_cell_type                          5
                             transcription_factor                     28
                             transcription_factor,immune_cell_type     8
SecretedCytokine             cytokine                                  9
                             cytokine,immune_cell_type                 3
                             immune_cell_type                         36
dtype: int64

In [21]:
# Find documents with most annotations unable to be matched and either improve tagging or change annotations
df_miss.groupby(['relation', 'doc_id']).size().sort_values(ascending=False).head(15)

relation                     doc_id    
SecretedCytokine             PMC3046151    13
                             PMC2196041    11
InducingTranscriptionFactor  PMC2587175    11
                             PMC2646571     8
                             PMC3173465     7
                             PMC2783637     6
SecretedCytokine             PMC2193209     5
InducingCytokine             PMC2196041     5
SecretedCytokine             PMC3650071     4
InducingCytokine             PMC3173465     4
InducingTranscriptionFactor  PMC3304099     4
SecretedCytokine             PMC4385920     4
InducingTranscriptionFactor  PMC4474185     3
InducingCytokine             PMC2646571     3
SecretedCytokine             PMC3204990     2
dtype: int64

In [22]:
from snorkel.annotations import load_gold_labels, load_matrix
from snorkel.models import Candidate

L_gold = {}
for c in classes:
    cids_query = get_cids_query(session, classes[c], split=1)
    L_gold[c] = load_gold_labels(session, annotator_name=classes[c].field, split=1, cids_query=cids_query)

In [23]:
for c in L_gold:
    print(c, L_gold[c].shape)

InducingCytokine (593, 1)
SecretedCytokine (593, 1)
InducingTranscriptionFactor (406, 1)


In [24]:
for c in classes.values():
    for split in [0, 1]:
        n = session.query(c.subclass).filter(c.subclass.split == split).count()
        print('Candidate counts: {} (split {}) -> {}'.format(c.name, split, n))

Candidate counts: InducingCytokine (split 0) -> 11836
Candidate counts: InducingCytokine (split 1) -> 593
Candidate counts: SecretedCytokine (split 0) -> 11836
Candidate counts: SecretedCytokine (split 1) -> 593
Candidate counts: InducingTranscriptionFactor (split 0) -> 8477
Candidate counts: InducingTranscriptionFactor (split 1) -> 406


In [25]:
# Make sure that inducing/secreted cytokine labels are mutally exclusive
L_df = pd.DataFrame(np.hstack((
    L_gold[classes.inducing_cytokine.name].toarray(), 
    L_gold[classes.secreted_cytokine.name].toarray()
)))
L_df.groupby([0, 1]).size()

0  1
0  0    416
   1     65
1  0    112
dtype: int64

In [26]:
pd.Series([
    type(L_gold[classes.inducing_cytokine.name].get_candidate(session, i)).__name__ 
    for i in range(L_gold[classes.inducing_cytokine.name].shape[0])
]).value_counts()

InducingCytokine    593
dtype: int64