In [1]:
import os
import os.path as osp
import pandas as pd
import snorkel
import dotenv
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel import SnorkelSession

# Load environment variables (DATA_DIR is read below) from the project's env file.
dotenv.load_dotenv('../env.sh')
# Corpus root; its `docs` folder is handed to TextDocPreprocessor and `tags.csv` holds the entity tags.
corpus_dir = osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_00')
corpus_docs_dir = osp.join(corpus_dir, 'docs')
# Snorkel session used by later cells for `session.query(...)`.
session = SnorkelSession()

In [2]:
# Read at most 10 documents. NOTE(review): presumably a small development sample — raise/remove for a full run.
doc_preprocessor = TextDocPreprocessor(corpus_docs_dir, max_docs=10)

In [3]:
# One row per pre-tagged entity mention, with character (start_chr/end_chr) and
# word (start_wrd/end_wrd) offsets. The file was written together with its pandas
# index, so read that first column back as the index instead of keeping it as a
# spurious 'Unnamed: 0' data column.
tags = pd.read_csv(osp.join(corpus_dir, 'tags.csv'), index_col=0)
tags.head()

Unnamed: 0,id,type,ent_typ_id,ent_typ_lbl,start_chr,end_chr,start_wrd,end_wrd,text
0,PMC5743442,CELL_TYPE,4,Th17,65,69,8,9,Th17
1,PMC5743442,CELL_TYPE,4,Th17,87,91,11,12,Th17
2,PMC5743442,CYTOKINE,81,TGF-β1,174,179,25,26,TGF-β
3,PMC5743442,CELL_TYPE,4,Th17,199,203,29,30,Th17
4,PMC5743442,CYTOKINE,81,TGF-β1,275,280,42,43,TGF-β


In [4]:
from string import punctuation

def offsets_to_token(left, right, offset_array, lemmas=None, punc=set(punctuation)):
    """Find the token range that spans the character range [`left`, `right`].

    The range starts at the last token whose character offset is <= `left` and
    stops (exclusive) at the first token whose character offset is > `right`.
    When no token starts after `right`, the range runs through the last token,
    so an entity that closes the sentence is still covered. If the token just
    before the stop index is punctuation (e.g. a "." that starts exactly at
    `right`), it is trimmed from the range.

    Example: offsets_to_token(15, 25, [0, 10, 20, 30]) --> range(1, 3)

    This is useful for identifying all tokens in a document that span a range
    of characters determined by another process that may have tokenized the same
    document differently.

    Args:
        left: character offset where the target span starts.
        right: character offset where the target span ends.
        offset_array: character offset of each token's first character; assumed
            ascending (the scan stops at the first offset past `right`).
        lemmas: optional per-token lemmas aligned with `offset_array`; only used
            for the trailing-punctuation trim, which is skipped when None.
        punc: strings treated as punctuation by the trim.

    Returns:
        A `range` of token indexes (empty when `offset_array` is empty).
    """
    if len(offset_array) == 0:
        return range(0, 0)
    # Default to the first token so a span starting before the first offset
    # still has a defined start (range(None, ...) would raise TypeError).
    token_start, token_end = 0, None
    for i, c in enumerate(offset_array):
        if left >= c:
            token_start = i
        if c > right:
            token_end = i
            break
    if token_end is None:
        # No token begins after the span: it extends through the final token.
        token_end = len(offset_array)
    # Drop a trailing punctuation token picked up because it starts at/near `right`.
    if lemmas is not None and token_end > token_start and lemmas[token_end - 1] in punc:
        token_end -= 1
    return range(token_start, token_end)


class EntityTagger(object):
    """Attach pre-computed entity tags to the tokens of parsed sentences.

    `tags` is a DataFrame with one row per tagged mention; the columns read here
    are `id` (document name), `type`, `ent_typ_lbl`, `start_chr` and `end_chr`.
    Coverage statistics (documents seen / documents that had tags) accumulate
    across calls to `tag` until `reset_stats` is called.
    """

    def __init__(self, tags):
        # Index by document id so each sentence can fetch its document's tags.
        self.tags = tags.set_index('id')
        self.reset_stats()

    def reset_stats(self):
        """Clear the seen/found document sets; returns self for chaining."""
        self.stats = {'docs': set(), 'found': set()}
        return self

    def get_stats(self):
        """Summarize tagging coverage.

        Returns:
            dict with the number of tag rows (`n_tags`), distinct documents seen
            by `tag` (`n_docs`), how many of those had at least one tag row
            (`n_docs_found`), and that share as a percentage (`pct_docs_found`,
            0.0 before any document has been seen).
        """
        n_docs = len(self.stats['docs'])
        n_found = len(self.stats['found'])
        return dict(
            n_tags=len(self.tags),
            n_docs=n_docs,
            n_docs_found=n_found,
            # Guard against ZeroDivisionError when called before any tagging.
            pct_docs_found=100 * n_found / n_docs if n_docs else 0.0,
        )

    def tag(self, parts):
        """Tag tokens in a single sentence.

        Args:
            parts: dict of parsed sentence fields. Reads `stable_id`,
                `char_offsets` and `lemmas`; writes into the `entity_types`
                and `entity_cids` lists in place.

        Returns:
            The same `parts` dict.
        """
        # stable_id splits on ':' into five fields: doc id (e.g. PMC123932) first,
        # sentence start/end character offsets last.
        docid, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        self.stats['docs'].add(docid)
        if docid not in self.tags.index:
            return parts
        self.stats['found'].add(docid)
        doc_tags = self.tags.loc[[docid]]
        sent_start, sent_end = int(sent_start), int(sent_end)
        # char_offsets are treated as sentence-relative: shift them once into
        # document coordinates so they compare with the tags' offsets.
        offsets = [offset + sent_start for offset in parts['char_offsets']]
        for r in doc_tags.itertuples():
            tag_start, tag_end = r.start_chr, r.end_chr
            # Only tags that start inside this sentence are applied here.
            if not (sent_start <= tag_start <= sent_end):
                continue
            tkn_idx_rng = offsets_to_token(tag_start, tag_end, offsets, parts['lemmas'])
            for tkn_idx in tkn_idx_rng:
                parts['entity_types'][tkn_idx] = r.type.lower()
                parts['entity_cids'][tkn_idx] = r.ent_typ_lbl  # TODO: use the id of the entity instead
        return parts

In [5]:
from snorkel.parser import CorpusParser, Spacy

tagger = EntityTagger(tags)
# tagger.tag is the per-sentence callback that writes entity types/ids into the parsed token fields.
corpus_parser = CorpusParser(fn=tagger.tag)
# NOTE(review): materialized as a list, presumably so the parser knows the document count for its progress bar — confirm.
corpus_parser.apply(list(doc_preprocessor))

  0%|          | 0/10 [00:00<?, ?it/s]

Clearing existing...
Running UDF...


100%|██████████| 10/10 [00:07<00:00,  1.11s/it]


In [6]:
# Tag coverage: how many parsed documents had at least one row in tags.csv.
tagger.get_stats()

{'n_tags': 23310, 'n_docs': 10, 'n_docs_found': 8, 'pct_docs_found': 80.0}

In [7]:
from snorkel.models import Document, Sentence

# Sanity check: number of documents and sentences now stored by the parser.
print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 10
Sentences: 1675


In [8]:
docs = session.query(Document).all()
# Document names double as the ids used to split the corpus into train/test.
doc_ids = [document.name for document in docs]

In [9]:
from sklearn.model_selection import train_test_split
# Make the partition reproducible across Restart & Run All: the document query is
# not explicitly ordered, so sort the ids first, then fix the split's seed.
train_ids, test_ids = train_test_split(sorted(doc_ids), test_size=.1, random_state=0)
len(train_ids), len(test_ids)

(9, 1)

In [10]:
# Assign every sentence to the split of its parent document. Membership is checked
# once per document, against sets (O(1) lookups), rather than once per sentence.
train_id_set, test_id_set = set(train_ids), set(test_ids)
train_sents, test_sents = set(), set()
for doc in docs:
    if doc.name in train_id_set:
        split_sents = train_sents
    elif doc.name in test_id_set:
        split_sents = test_sents
    else:
        raise Exception('ID <{0}> not found in any id set'.format(doc.name))
    split_sents.update(doc.sentences)

In [24]:
# for sent in session.query(Sentence).all():
#     if 'cytokine' in sent.entity_types and 'cell_type' in sent.entity_types:
#         print(set(sent.entity_types))

In [13]:
#sent.entity_types

In [23]:
#list(train_sents)[0].entity_types

In [32]:
from snorkel.models import Candidate, candidate_subclass
from snorkel.candidates import PretaggedCandidateExtractor

# Candidate type with two arguments, named after the entity types tagged above.
InducingCytokine = candidate_subclass('InducingCytokine', ['cytokine', 'cell_type'])
# The entity types given here must match the strings EntityTagger writes into
# `entity_types` (the lower-cased `type` column of tags.csv, e.g. 'cell_type').
# NOTE(review): earlier experiments suggested names containing "_" need an exact match,
# possibly due to a camel-case conversion inside Snorkel — unconfirmed.
candidate_extractor = PretaggedCandidateExtractor(InducingCytokine, ['cytokine', 'cell_type'])

In [33]:
# Extract candidates per split (k=0: training sentences, k=1: test sentences) and report how many each produced.
for k, sents in enumerate([train_sents, test_sents]):
    candidate_extractor.apply(sents, split=k)
    print("Number of candidates:", session.query(InducingCytokine).filter(InducingCytokine.split == k).count())

 42%|████▏     | 580/1375 [00:00<00:00, 5726.52it/s]

Clearing existing...
Running UDF...


100%|██████████| 1375/1375 [00:00<00:00, 5142.18it/s]
100%|██████████| 300/300 [00:00<00:00, 2794.49it/s]

Number of candidates: 58
Clearing existing...
Running UDF...
Number of candidates: 43



