In [5]:
import os
import os.path as osp
import pandas as pd
import numpy as np
import snorkel
import dotenv
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel import SnorkelSession

dotenv.load_dotenv('../env.sh')
corpus_dir = osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_00')
corpus_docs_dir = osp.join(corpus_dir, 'links')
#collection_dir = osp.join(os.environ['REPO_DATA_DIR'], 'brat', 'collection_01')
collection_dir = osp.join(os.environ['REPO_DATA_DIR'], 'brat', 'collection_02')
session = SnorkelSession()

In [6]:
#doc_preprocessor = TextDocPreprocessor(corpus_docs_dir, max_docs=10)
doc_preprocessor = TextDocPreprocessor(corpus_docs_dir)

In [10]:
tags = pd.read_csv(osp.join(corpus_dir, 'tags.csv'))
tags = tags[tags['type'].isin(['TCELL_TYPE', 'TRANSCRIPTION_FACTOR', 'CYTOKINE'])]
print(tags['type'].value_counts())
tags.head()

CYTOKINE                19632
TCELL_TYPE              19514
TRANSCRIPTION_FACTOR    16834
Name: type, dtype: int64


Unnamed: 0,id,type,ent_id,ent_lbl,start_chr,end_chr,start_wrd,end_wrd,text
1,PMC5743442,TCELL_TYPE,CTBFBDE5121B6748D1,Th17,65,69,8,9,Th17
2,PMC5743442,TCELL_TYPE,CTBFBDE5121B6748D1,Th17,87,91,11,12,Th17
3,PMC5743442,CYTOKINE,CK1C5A7BDC959D1365,TGF-β1,174,179,25,26,TGF-β
4,PMC5743442,TCELL_TYPE,CTBFBDE5121B6748D1,Th17,199,203,29,30,Th17
6,PMC5743442,CYTOKINE,CK1C5A7BDC959D1365,TGF-β1,275,280,42,43,TGF-β


In [11]:
def get_annotated_doc_ids():
    import glob
    return [osp.splitext(osp.basename(f))[0] for f in glob.glob(osp.join(collection_dir, '*.txt'))]
annotated_ids = get_annotated_doc_ids()
annotated_ids
{did:did in tags['id'].values for did in annotated_ids}

{'PMC4423225': True,
 'PMC3850168': True,
 'PMC4056277': True,
 'PMC3246047': True,
 'PMC4337382': True,
 'PMC6197911': True,
 'PMC4291544': True,
 'PMC5876181': True,
 'PMC6282816': True,
 'PMC5611846': True,
 'PMC5578684': True,
 'PMC3317433': False,
 'PMC2196041': True,
 'PMC5343661': True,
 'PMC5020626': True,
 'PMC4905708': True,
 'PMC4214202': True,
 'PMC4100769': False,
 'PMC3321800': True,
 'PMC5118948': True,
 'PMC5052263': True,
 'PMC2646571': True,
 'PMC3750006': True,
 'PMC4474185': True,
 'PMC5464295': True,
 'PMC5191835': True,
 'PMC5206501': True,
 'PMC3304099': True,
 'PMC2805085': True,
 'PMC4241840': True,
 'PMC4233385': True,
 'PMC3064981': True,
 'PMC5429091': True,
 'PMC3639604': True,
 'PMC5749247': False,
 'PMC5611819': True,
 'PMC3926063': True,
 'PMC6141714': True,
 'PMC6373736': False,
 'PMC3092345': True,
 'PMC200936': True,
 'PMC4959015': True,
 'PMC6122729': True,
 'PMC6372559': True,
 'PMC4084624': True,
 'PMC2634967': True,
 'PMC4592272': True,
 'PMC38421

In [5]:
from string import punctuation

def offsets_to_token(left, right, offset_array, lemmas, punc=set(punctuation)):
    token_start, token_end = None, None
    for i, c in enumerate(offset_array):
        if left >= c:
            token_start = i
        if c > right and token_end is None:
            token_end = i
            break
    token_end = len(offset_array) - 1 if token_end is None else token_end
    token_end = token_end - 1 if lemmas[token_end - 1] in punc else token_end
    return range(token_start, token_end)


class EntityTagger(object):

    def __init__(self, tags):   
        self.tags = tags.set_index('id')
        self.reset_stats()

    def reset_stats(self):
        self.stats = {'docs': set(), 'found': set()}
        return self
    
    def get_stats(self):
        return dict(
            n_tags=len(self.tags), 
            n_docs=len(self.stats['docs']),
            n_docs_found=len(self.stats['found']),
            pct_docs_found=100*len(self.stats['found'])/len(self.stats['docs'])
        )
    
    def tag(self, parts):
        """Tag tokens in a single sentence"""
        # Extract doc id (e.g. PMC123932) and character offsets of sentence
        docid, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        self.stats['docs'].add(docid)
        if docid not in self.tags.index:
            return parts
        self.stats['found'].add(docid)
        tags = self.tags.loc[[docid]]
        sent_start, sent_end = int(sent_start), int(sent_end)
        for r in tags.itertuples():
            tag_start, tag_end = r.start_chr, r.end_chr
            # Determine whether or not the tag is in this sentence
            if not (sent_start <= tag_start <= sent_end):
                continue
            offsets = [offset + sent_start for offset in parts['char_offsets']]
            tkn_idx_rng = offsets_to_token(tag_start, tag_end, offsets, parts['lemmas'])
            for tkn_idx in tkn_idx_rng:
                parts['entity_types'][tkn_idx] = r.type.lower()
                parts['entity_cids'][tkn_idx] = r.ent_id
        return parts

In [6]:
from snorkel.parser import CorpusParser, Spacy

tagger = EntityTagger(tags)
corpus_parser = CorpusParser(fn=tagger.tag)
corpus_parser.apply(list(doc_preprocessor))

Clearing existing...


  0%|          | 0/510 [00:00<?, ?it/s]

Running UDF...


100%|██████████| 510/510 [05:53<00:00,  1.44it/s]


In [7]:
# This will show how many documents didn't have any tagged
# (or that were otherwise not included in tagging but included
# here in parsing -- which should be rare)
tagger.get_stats()

{'n_tags': 34923,
 'n_docs': 510,
 'n_docs_found': 440,
 'pct_docs_found': 86.27450980392157}

In [14]:
from snorkel.models import Document, Sentence

print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 510
Sentences: 72944


In [15]:
docs = session.query(Document).all()

In [16]:
all_ids = [doc.name for doc in docs]
tagged_ids = list(np.intersect1d(all_ids, tags['id']))
dev_ids = list(set(annotated_ids))
train_ids = [i for i in tagged_ids if i not in dev_ids]
len(train_ids), len(dev_ids)

(430, 10)

In [17]:
train_sents, dev_sents = set(), set()
for i, doc in enumerate(docs):
    if doc.name not in tagged_ids:
        continue
    for s in doc.sentences:
        if doc.name in train_ids:
            train_sents.add(s)
        elif doc.name in dev_ids:
            dev_sents.add(s)
        else:
            raise Exception('ID <{0}> not found in any id set'.format(doc.name))

In [24]:
# for sent in session.query(Sentence).all():
#     if 'cytokine' in sent.entity_types and 'cell_type' in sent.entity_types:
#         print(set(sent.entity_types))

In [13]:
#sent.entity_types

In [18]:
from snorkel.models import Candidate, candidate_subclass
from snorkel.candidates import PretaggedCandidateExtractor

# * Make sure SecretedCytokine gives cytokine + cell type in same order as they will share rules
# for labeling functions
InducingCytokine = candidate_subclass('InducingCytokine', ['cytokine', 'cell_type'])
# The entity_types passed here appear to need to be exact matches on the strings
# provided above if "_" is included (must be some kind of camel case conversion oversight)
candidate_extractor = PretaggedCandidateExtractor(InducingCytokine, ['cytokine', 'cell_type'])

In [19]:
for k, sents in enumerate([train_sents, dev_sents]):
    candidate_extractor.apply(sents, split=k)
    print("Number of candidates:", session.query(InducingCytokine).filter(InducingCytokine.split == k).count())

  0%|          | 167/71838 [00:00<00:43, 1649.94it/s]

Clearing existing...
Running UDF...


100%|██████████| 71838/71838 [00:31<00:00, 2247.33it/s]
 18%|█▊        | 113/644 [00:00<00:00, 1104.73it/s]

Number of candidates: 8011
Clearing existing...
Running UDF...


100%|██████████| 644/644 [00:00<00:00, 1346.03it/s]

Number of candidates: 168





## Load Relations

In [20]:
relations = pd.read_csv(osp.join(corpus_dir, 'relations.csv'))
relations.head()

Unnamed: 0,e1_end_chr,e1_start_chr,e1_text,e1_typ,e2_end_chr,e2_start_chr,e2_text,e2_typ,id,rel_typ
0,24,19,Gfi-1,TF,85,81,Th17,CELL_TYPE,PMC2646571,Differentiation
1,44,39,TGF-β,CYTOKINE,85,81,Th17,CELL_TYPE,PMC2646571,Induction
2,44,39,TGF-β,CYTOKINE,125,90,CD103+ inducible regulatory T cells,CELL_TYPE,PMC2646571,Induction
3,24,19,Gfi-1,TF,125,90,CD103+ inducible regulatory T cells,CELL_TYPE,PMC2646571,Differentiation
4,371,366,Gfi-1,TF,436,433,Th2,CELL_TYPE,PMC2646571,Differentiation


In [21]:
# person1	person2	label
# 36c3703b-bd5b-4888-be46-2f45bcb37f8e::span:95:106	36c3703b-bd5b-4888-be46-2f45bcb37f8e::span:0:10	1
# e16a971f-23ce-42e4-81df-b2386126f8b3::span:126:134	e16a971f-23ce-42e4-81df-b2386126f8b3::span:140:157	-1

In [22]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

def get_stable_id(r):
    return '{}::span:{}:{}'.format(r['id'], r['start_chr'], r['end_chr']-1)

def load_external_labels(session, relations, candidate_class, annotator_name='gold'):
    for i, r in relations.iterrows():    

        # We check if the label already exists, in case this cell was already executed
        e1_id = get_stable_id(r.filter(regex='^e1_|^id$').rename(lambda v: v.replace('e1_', '')))
        e2_id = get_stable_id(r.filter(regex='^e2_|^id$').rename(lambda v: v.replace('e2_', '')))
        
        context_stable_ids = "~~".join([e1_id, e2_id])
        query = session.query(StableLabel)\
            .filter(StableLabel.context_stable_ids == context_stable_ids)\
            .filter(StableLabel.annotator_name == annotator_name)
        if query.count() == 0:
            session.add(StableLabel(
                context_stable_ids=context_stable_ids,
                annotator_name=annotator_name,
                value=1
            ))
            
    session.commit()
    reload_annotator_labels(
        session, candidate_class, annotator_name, split=1, 
        filter_label_split=False, create_missing_cands=False)
    
load_external_labels(session, relations[relations['rel_typ'] == 'Induction'], InducingCytokine)

AnnotatorLabels created: 30


In [78]:
# session.execute('DELETE FROM stable_label;')
# session.execute('DELETE FROM gold_label;')
# session.execute('DELETE FROM gold_label_key;')
# from snorkel.models import GoldLabel, GoldLabelKey, StableLabel
# session.query(StableLabel).count(), session.query(GoldLabel).count(), session.query(GoldLabelKey).count()

In [79]:
# stbl_ids = np.unique([
#     v
#     for r in session.query(StableLabel).all() 
#     for v in r.context_stable_ids.split('~~')
# ])
# cand_ids = np.unique([
#     cand.get_stable_id()
#     for r in session.query(InducingCytokine).all()
#     for cand in r
# ])
# len(np.intersect1d(stbl_ids, cand_ids)), len(stbl_ids), len(cand_ids)

(46, 134, 4892)

In [23]:
from snorkel.annotations import load_gold_labels

L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

In [24]:
L_gold_dev.shape, session.query(InducingCytokine).filter(InducingCytokine.split == 1).count()

((168, 1), 168)