In [1]:
import os
import os.path as osp
import pandas as pd
import numpy as np
import snorkel
import dotenv
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel import SnorkelSession

# Load environment variables from the repo-level env file; DATA_DIR and
# REPO_DATA_DIR are read from the environment right below
dotenv.load_dotenv('../env.sh')
# Root folder of the article corpus; `links` is the folder handed to the document preprocessor
corpus_dir = osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_00')
corpus_docs_dir = osp.join(corpus_dir, 'links')
# Folder scanned for `*.txt` files to determine which documents are annotated
collection_dir = osp.join(os.environ['REPO_DATA_DIR'], 'brat', 'collection_01')
# Snorkel database session shared by all cells below
session = SnorkelSession()

In [2]:
# Debug variant limited to 10 documents (swap with the line below for a quick run):
#doc_preprocessor = TextDocPreprocessor(corpus_docs_dir, max_docs=10)
doc_preprocessor = TextDocPreprocessor(corpus_docs_dir)

In [3]:
# Entity tags: one row per tagged mention, keyed by document `id`, with
# character (`start_chr`/`end_chr`) and word (`start_wrd`/`end_wrd`) offsets.
# NOTE(review): the file was apparently saved with its index, which shows up
# as the `Unnamed: 0` column; `index_col=0` would drop it -- confirm nothing relies on it
tags = pd.read_csv(osp.join(corpus_dir, 'tags.csv'))
tags.head()

Unnamed: 0,id,type,ent_typ_id,ent_typ_lbl,start_chr,end_chr,start_wrd,end_wrd,text
0,PMC5743442,CELL_TYPE,4,Th17,65,69,8,9,Th17
1,PMC5743442,CELL_TYPE,4,Th17,87,91,11,12,Th17
2,PMC5743442,CYTOKINE,81,TGF-β1,174,179,25,26,TGF-β
3,PMC5743442,CELL_TYPE,4,Th17,199,203,29,30,Th17
4,PMC5743442,CYTOKINE,81,TGF-β1,275,280,42,43,TGF-β


In [4]:
def get_annotated_doc_ids(directory=None):
    """List the ids of documents that have a ``.txt`` file in an annotation folder.

    Args:
        directory: Folder to scan. Defaults to the notebook-level
            ``collection_dir`` when omitted.

    Returns:
        Sorted list of file base names with the extension removed
        (e.g. ``'PMC2646571'`` for ``PMC2646571.txt``).
    """
    import glob
    directory = collection_dir if directory is None else directory
    # glob yields files in arbitrary, filesystem-dependent order; sort so that
    # re-running the notebook always produces the same listing
    return sorted(
        osp.splitext(osp.basename(f))[0]
        for f in glob.glob(osp.join(directory, '*.txt'))
    )
annotated_ids = get_annotated_doc_ids()
# For each annotated document, show whether the tag table has any entry for it.
# The ids are collected into a set once instead of scanning the column per document.
ids_with_tags = set(tags['id'])
{did: did in ids_with_tags for did in annotated_ids}

{'PMC2646571': True,
 'PMC3304099': True,
 'PMC2634967': True,
 'PMC3189223': True,
 'PMC3095633': False,
 'PMC2193209': True,
 'PMC2938478': True,
 'PMC3235500': True,
 'PMC3046151': True,
 'PMC3173465': True}

In [5]:
from string import punctuation

def offsets_to_token(left, right, offset_array, lemmas, punc=set(punctuation)):
    """Find the token range that spans character offsets

    This will find the indexes of tokens, as a range, that span a target character
    range: the first index belongs to the last token whose character offset is
    <= `left`, and the (exclusive) stop index belongs to the first token whose
    character offset is > `right`. If no token starts after `right`, the range
    runs through the final token. A single trailing punctuation token is
    dropped from the end of the range.

    Example: offsets_to_token(15, 25, [0, 10, 20, 30], ['a', 'b', 'c', 'd']) --> range(1, 3)

    This is useful for identifying all tokens in a document that span a range
    of characters determined by another process that may have tokenized the same
    document differently.

    Args:
        left: First character offset of the target span.
        right: Last character offset of the target span.
        offset_array: Start character offset of every token, in ascending order.
        lemmas: One lemma per token (same length as `offset_array`); only used
            to detect a trailing punctuation token.
        punc: Collection of strings treated as punctuation.

    Returns:
        A `range` of token indexes (possibly empty).
    """
    token_start, token_end = None, None
    for i, c in enumerate(offset_array):
        if left >= c:
            token_start = i
        if c > right:
            token_end = i
            break
    # `left` precedes every token: start from the first token rather than
    # handing `None` to range()
    if token_start is None:
        token_start = 0
    # No token starts after `right`: the span reaches the last token, so the
    # exclusive stop index is the number of tokens (not len - 1, which would
    # silently drop the final token)
    if token_end is None:
        token_end = len(offset_array)
    # Drop one trailing punctuation token; the `> 0` guard prevents index -1
    # from wrapping around to the end of `lemmas`
    if token_end > 0 and lemmas[token_end - 1] in punc:
        token_end -= 1
    return range(token_start, token_end)


class EntityTagger(object):
    """Copies pre-computed entity tags onto the tokens of parsed sentences.

    `tags` is a DataFrame with (at least) the columns `id`, `type`,
    `ent_typ_lbl`, `start_chr` and `end_chr`. The instance also records which
    document ids were seen by `tag` and which of them had tags available.
    """

    def __init__(self, tags):
        # Index by document id so all tags of a document come from one lookup
        self.tags = tags.set_index('id')
        self.reset_stats()

    def reset_stats(self):
        """Clear the sets of seen/found document ids; returns self."""
        self.stats = {'docs': set(), 'found': set()}
        return self

    def get_stats(self):
        """Summarize tagging coverage.

        Returns:
            dict with the number of tag rows (`n_tags`), the number of distinct
            documents seen (`n_docs`), how many of those had tags
            (`n_docs_found`) and that share as a percentage (`pct_docs_found`,
            0.0 when no document has been seen yet).
        """
        n_docs = len(self.stats['docs'])
        n_found = len(self.stats['found'])
        return dict(
            n_tags=len(self.tags),
            n_docs=n_docs,
            n_docs_found=n_found,
            # Guard against division by zero before any sentence was tagged
            pct_docs_found=100 * n_found / n_docs if n_docs else 0.0
        )

    def tag(self, parts):
        """Tag tokens in a single sentence

        Args:
            parts: dict describing one sentence. Reads `stable_id` (five
                ':'-separated fields: doc id, two ignored fields, sentence
                start and end character offsets), `char_offsets` and `lemmas`;
                writes into `entity_types` and `entity_cids` in place.

        Returns:
            The same `parts` object.
        """
        # Extract doc id (e.g. PMC123932) and character offsets of sentence
        docid, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        self.stats['docs'].add(docid)
        if docid not in self.tags.index:
            return parts
        self.stats['found'].add(docid)
        sent_start, sent_end = int(sent_start), int(sent_end)

        # Keep only the tags whose start offset falls inside this sentence.
        # The mask is applied as a plain array because the id index is repeated.
        doc_tags = self.tags.loc[[docid]]
        in_sentence = doc_tags['start_chr'].between(sent_start, sent_end).values
        sent_tags = doc_tags[in_sentence]
        if len(sent_tags) == 0:
            return parts

        # Token offsets are relative to the sentence; shift them to document
        # coordinates once so they are comparable with the tag offsets
        offsets = [offset + sent_start for offset in parts['char_offsets']]
        for r in sent_tags.itertuples():
            # NOTE(review): `end_chr` looks exclusive in the tag table (e.g. 65-69
            # for "Th17") while it is used here as the right bound -- confirm
            tkn_idx_rng = offsets_to_token(r.start_chr, r.end_chr, offsets, parts['lemmas'])
            for tkn_idx in tkn_idx_rng:
                parts['entity_types'][tkn_idx] = r.type.lower()
                parts['entity_cids'][tkn_idx] = r.ent_typ_lbl # TODO: use the id of the entity instead
        return parts

In [6]:
from snorkel.parser import CorpusParser, Spacy

# Parse the corpus, passing `tagger.tag` as the parser's `fn` hook so that the
# pre-computed entity tags are written onto the parsed sentences
tagger = EntityTagger(tags)
corpus_parser = CorpusParser(fn=tagger.tag)
# The preprocessor is materialized into a list before being handed to the parser
corpus_parser.apply(list(doc_preprocessor))

  0%|          | 0/510 [00:00<?, ?it/s]

Clearing existing...
Running UDF...


 16%|█▌        | 82/510 [00:41<03:36,  1.98it/s]

0 241 39 44 7 cytokine TGF-β1
0 241 81 85 13 cell_type Th17
242 324 305 308 12 cell_type Th2
325 554 433 436 18 cell_type Th2
325 554 509 513 32 cell_type Th17
555 899 566 569 2 cell_type Th2
555 899 695 698 27 cell_type Th2
555 899 767 771 39 cell_type Th17
555 899 804 807 47 cell_type Th2
555 899 851 856 54 cytokine IL-17
900 1105 935 939 7 cell_type Th17
1481 1614 1533 1536 10 cell_type Th2
1481 1614 1583 1587 18 cell_type Th17
1614 1865 1728 1731 20 cell_type Th1
1614 1865 1736 1739 22 cell_type Th2
1614 1865 1805 1810 39 cytokine IL-17
1614 1865 1812 1816 41 cell_type Th17
1925 2017 1925 1928 0 cell_type Th1
1925 2017 1950 1955 4 cytokine IFN-γ
2018 2168 2018 2021 0 cell_type Th2
2018 2168 2036 2040 3 cytokine IL-4
2018 2168 2042 2046 5 cytokine IL-5
2018 2168 2052 2057 8 cytokine IL-13
2169 2293 2169 2173 0 cell_type Th17
2294 2413 2294 2297 0 cell_type Th1
2294 2413 2299 2302 2 cell_type Th2
2294 2413 2308 2312 5 cell_type Th17
2414 2566 2420 2423 2 cell_type Th1
2414 2566 2428 

100%|██████████| 510/510 [05:15<00:00,  1.62it/s]


In [7]:
# Shows how many parsed documents had no tags at all, i.e. documents that
# were parsed here but are absent from the tag table (expected to be rare).
# NOTE(review): the recorded run found tags for only 399 of 510 documents,
# which is more than "rare" -- worth checking why 111 documents have no tags
tagger.get_stats()

{'n_tags': 23745,
 'n_docs': 510,
 'n_docs_found': 399,
 'pct_docs_found': 78.23529411764706}

In [8]:
from snorkel.models import Document, Sentence

# Sanity check: number of documents and sentences stored by the parser
print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 510
Sentences: 72944


In [9]:
# Load every parsed document; used below to build the train/dev sentence sets
docs = session.query(Document).all()

In [12]:
# Split document ids: annotated documents form the dev set, every other
# document that has tags goes to the training set
all_ids = [doc.name for doc in docs]
# Documents that were parsed AND have at least one row in the tag table
tagged_ids = list(np.intersect1d(all_ids, tags['id']))
# NOTE(review): dev_ids is not intersected with tagged_ids, so an annotated
# document without tags is counted here but contributes no sentences in the
# split loop that follows -- confirm the reported dev size is the intended one
dev_ids = list(set(annotated_ids))
train_ids = [i for i in tagged_ids if i not in dev_ids]
len(train_ids), len(dev_ids)

(390, 10)

In [13]:
# Assign all sentences of each tagged document to the train or dev split.
# Membership is tested against sets, and once per document instead of once per
# sentence, which avoids repeated linear scans of the id lists.
tagged_id_set, train_id_set, dev_id_set = set(tagged_ids), set(train_ids), set(dev_ids)
train_sents, dev_sents = set(), set()
for doc in docs:
    # Documents without any tags cannot yield candidates; leave them out
    if doc.name not in tagged_id_set:
        continue
    if doc.name in train_id_set:
        split_sents = train_sents
    elif doc.name in dev_id_set:
        split_sents = dev_sents
    else:
        # Should be unreachable given how train_ids is derived from tagged_ids
        raise Exception('ID <{0}> not found in any id set'.format(doc.name))
    split_sents.update(doc.sentences)

In [24]:
# for sent in session.query(Sentence).all():
#     if 'cytokine' in sent.entity_types and 'cell_type' in sent.entity_types:
#         print(set(sent.entity_types))

In [13]:
#sent.entity_types

In [14]:
from snorkel.models import Candidate, candidate_subclass
from snorkel.candidates import PretaggedCandidateExtractor

# Candidate type pairing a `cytokine` mention with a `cell_type` mention
InducingCytokine = candidate_subclass('InducingCytokine', ['cytokine', 'cell_type'])
# The entity types passed to the extractor appear to need to match the argument
# names given above exactly when they contain "_" (possibly a camel-case
# conversion oversight in the library)
candidate_extractor = PretaggedCandidateExtractor(InducingCytokine, ['cytokine', 'cell_type'])

In [15]:
# Extract candidates for each split; the position in the list is used as the
# split number (0 = train sentences, 1 = dev sentences)
for k, sents in enumerate([train_sents, dev_sents]):
    candidate_extractor.apply(sents, split=k)
    print("Number of candidates:", session.query(InducingCytokine).filter(InducingCytokine.split == k).count())

  0%|          | 304/71130 [00:00<00:23, 3030.49it/s]

Clearing existing...
Running UDF...


100%|██████████| 71130/71130 [00:21<00:00, 3302.99it/s]
 31%|███▏      | 181/577 [00:00<00:00, 1799.01it/s]

Number of candidates: 4334
Clearing existing...
Running UDF...


100%|██████████| 577/577 [00:00<00:00, 1460.06it/s]

Number of candidates: 120





## Load Relations

In [71]:
# Annotated relations: one row per (e1, e2) entity pair with character offsets,
# entity types, the document `id` and the relation type (`rel_typ`)
relations = pd.read_csv(osp.join(corpus_dir, 'relations.csv'))
relations.head()

Unnamed: 0,e1_end_chr,e1_start_chr,e1_text,e1_typ,e2_end_chr,e2_start_chr,e2_text,e2_typ,id,rel_typ
0,24,19,Gfi-1,TF,85,81,Th17,CELL_TYPE,PMC2646571,Differentiation
1,44,39,TGF-β,CYTOKINE,85,81,Th17,CELL_TYPE,PMC2646571,Induction
2,44,39,TGF-β,CYTOKINE,125,90,CD103+ inducible regulatory T cells,CELL_TYPE,PMC2646571,Induction
3,24,19,Gfi-1,TF,125,90,CD103+ inducible regulatory T cells,CELL_TYPE,PMC2646571,Differentiation
4,371,366,Gfi-1,TF,436,433,Th2,CELL_TYPE,PMC2646571,Differentiation


In [None]:
# person1	person2	label
# 36c3703b-bd5b-4888-be46-2f45bcb37f8e::span:95:106	36c3703b-bd5b-4888-be46-2f45bcb37f8e::span:0:10	1
# e16a971f-23ce-42e4-81df-b2386126f8b3::span:126:134	e16a971f-23ce-42e4-81df-b2386126f8b3::span:140:157	-1

In [76]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

def get_stable_id(r):
    """Build the span stable id string for one entity mention.

    `r` is any mapping-like row with the keys `id`, `start_chr` and `end_chr`.
    The end offset written into the id is `end_chr - 1`.
    """
    doc_id, first_chr = r['id'], r['start_chr']
    last_chr = r['end_chr'] - 1
    return '{}::span:{}:{}'.format(doc_id, first_chr, last_chr)

def load_external_labels(session, relations, candidate_class, annotator_name='gold'):
    """Store one positive StableLabel per relation row, then reload annotator labels.

    Args:
        session: Database session used for querying, adding and committing.
        relations: DataFrame with `id` plus `e1_*` / `e2_*` columns (at least
            `start_chr` and `end_chr` for each entity).
        candidate_class: Candidate class passed on to `reload_annotator_labels`.
        annotator_name: Name under which the labels are stored.
    """
    for _, row in relations.iterrows():
        # Strip the e1_/e2_ prefixes so both entities look like a plain
        # (id, start_chr, end_chr, ...) row for get_stable_id
        e1_fields = row.filter(regex='^e1_|^id$').rename(lambda col: col.replace('e1_', ''))
        e2_fields = row.filter(regex='^e2_|^id$').rename(lambda col: col.replace('e2_', ''))
        context_stable_ids = "~~".join((get_stable_id(e1_fields), get_stable_id(e2_fields)))

        # Skip labels that already exist, in case this cell was already executed
        n_existing = (
            session.query(StableLabel)
            .filter(StableLabel.context_stable_ids == context_stable_ids)
            .filter(StableLabel.annotator_name == annotator_name)
            .count()
        )
        if n_existing != 0:
            continue
        session.add(StableLabel(
            context_stable_ids=context_stable_ids,
            annotator_name=annotator_name,
            value=1
        ))

    session.commit()
    # Reload annotator labels for split 1 without creating missing candidates
    # (NOTE(review): unclear why this reload step is required -- confirm)
    reload_annotator_labels(
        session, candidate_class, annotator_name, split=1,
        filter_label_split=False, create_missing_cands=False)
    
# Only relations of type 'Induction' are loaded as gold labels for InducingCytokine
load_external_labels(session, relations[relations['rel_typ'] == 'Induction'], InducingCytokine)

AnnotatorLabels created: 24


In [78]:
# session.execute('DELETE FROM stable_label;')
# session.execute('DELETE FROM gold_label;')
# session.execute('DELETE FROM gold_label_key;')
# from snorkel.models import GoldLabel, GoldLabelKey, StableLabel
# session.query(StableLabel).count(), session.query(GoldLabel).count(), session.query(GoldLabelKey).count()

In [79]:
# stbl_ids = np.unique([
#     v
#     for r in session.query(StableLabel).all() 
#     for v in r.context_stable_ids.split('~~')
# ])
# cand_ids = np.unique([
#     cand.get_stable_id()
#     for r in session.query(InducingCytokine).all()
#     for cand in r
# ])
# len(np.intersect1d(stbl_ids, cand_ids)), len(stbl_ids), len(cand_ids)

(46, 134, 4892)

In [81]:
from snorkel.annotations import load_gold_labels

# Gold label matrix for split 1 (the dev candidates), stored under annotator 'gold'
L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

In [88]:
# Sanity check: the label matrix should have one row per split-1 candidate
L_gold_dev.shape, session.query(InducingCytokine).filter(InducingCytokine.split == 1).count()

((120, 1), 120)