In [1]:
import os
import os.path as osp
import pandas as pd
import numpy as np
import snorkel
import dotenv
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel import SnorkelSession

dotenv.load_dotenv('../env.sh')
%run ../src/supervision.py
corpus_dir = osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_00')
corpus_docs_dir = osp.join(corpus_dir, 'links')
#collection_dir = osp.join(os.environ['REPO_DATA_DIR'], 'brat', 'collection_01')
collection_dir = osp.join(os.environ['REPO_DATA_DIR'], 'brat', 'collection_02')
session = SnorkelSession()

In [None]:
# Clear existing candidates, if need be

from snorkel.models import Candidate
ct = session.query(Candidate).count()
# if ct > 0:
#     # Clear all existing candidates (don't let extractors do it)
#     # See: https://github.com/HazyResearch/snorkel/blob/master/snorkel/candidates.py#L47
#     # *This seems to always cause a database lock somehow -- perhaps it needs to be done with autocommit but
#     # for now a workaround is to do this at the beginning and the restart the kernel
#     from snorkel.models import Candidate
#     ndelete = session.query(Candidate).delete()
#     session.commit()
#     # Restart kernel
ct

In [7]:
tags = pd.read_csv(osp.join(corpus_dir, 'tags.csv'))
tags = tags[tags['type'].isin(ENT_TYPES)]
print(tags['type'].value_counts())
tags.head()

CYTOKINE                22440
IMMUNE_CELL_TYPE        21023
TRANSCRIPTION_FACTOR    16806
Name: type, dtype: int64


Unnamed: 0,id,type,ent_id,ent_lbl,ent_prefid,start_chr,end_chr,start_wrd,end_wrd,text
1,PMC5743442,IMMUNE_CELL_TYPE,CTBFBDE5121B6748D1,Th17,CTBFBDE5121B6748D1,65,69,8,9,Th17
2,PMC5743442,IMMUNE_CELL_TYPE,CTBFBDE5121B6748D1,Th17,CTBFBDE5121B6748D1,87,91,11,12,Th17
3,PMC5743442,CYTOKINE,CKFD4CA0B2B4BC3AE4,TGF-β,CKFD4CA0B2B4BC3AE4,174,179,25,26,TGF-β
4,PMC5743442,IMMUNE_CELL_TYPE,CTBFBDE5121B6748D1,Th17,CTBFBDE5121B6748D1,199,203,29,30,Th17
6,PMC5743442,CYTOKINE,CKC59F8990F767EBD4,TGF-β,CKC59F8990F767EBD4,275,280,42,43,TGF-β


In [None]:
#doc_preprocessor = TextDocPreprocessor(corpus_docs_dir, max_docs=10)
doc_preprocessor = TextDocPreprocessor(corpus_docs_dir)

In [9]:
def get_annotated_doc_ids():
    import glob
    return [osp.splitext(osp.basename(f))[0] for f in glob.glob(osp.join(collection_dir, '*.txt'))]
annotated_ids = get_annotated_doc_ids()
# Show frequency of docs that are annotated AND have tags of some kind
pd.Series({did:did in tags['id'].values for did in annotated_ids}).value_counts()

True     102
False      7
dtype: int64

In [9]:
from string import punctuation

def offsets_to_token(left, right, offset_array, lemmas, punc=set(punctuation)):
    token_start, token_end = None, None
    for i, c in enumerate(offset_array):
        if left >= c:
            token_start = i
        if c > right and token_end is None:
            token_end = i
            break
    token_end = len(offset_array) - 1 if token_end is None else token_end
    token_end = token_end - 1 if lemmas[token_end - 1] in punc else token_end
    return range(token_start, token_end)


class EntityTagger(object):

    def __init__(self, tags):   
        self.tags = tags.set_index('id')
        self.reset_stats()

    def reset_stats(self):
        self.stats = {'docs': set(), 'found': set()}
        return self
    
    def get_stats(self):
        return dict(
            n_tags=len(self.tags), 
            n_docs=len(self.stats['docs']),
            n_docs_found=len(self.stats['found']),
            pct_docs_found=100*len(self.stats['found'])/len(self.stats['docs'])
        )
    
    def tag(self, parts):
        """Tag tokens in a single sentence"""
        # Extract doc id (e.g. PMC123932) and character offsets of sentence
        docid, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        self.stats['docs'].add(docid)
        if docid not in self.tags.index:
            return parts
        self.stats['found'].add(docid)
        tags = self.tags.loc[[docid]]
        sent_start, sent_end = int(sent_start), int(sent_end)
        for r in tags.itertuples():
            tag_start, tag_end = r.start_chr, r.end_chr
            # Determine whether or not the tag is in this sentence
            if not (sent_start <= tag_start <= sent_end):
                continue
            offsets = [offset + sent_start for offset in parts['char_offsets']]
            tkn_idx_rng = offsets_to_token(tag_start, tag_end, offsets, parts['lemmas'])
            for tkn_idx in tkn_idx_rng:
                parts['entity_types'][tkn_idx] = r.type.lower()
                parts['entity_cids'][tkn_idx] = r.ent_id + ':' + r.ent_prefid
        return parts

In [10]:
from snorkel.parser import CorpusParser, Spacy

tagger = EntityTagger(tags)
corpus_parser = CorpusParser(fn=tagger.tag)
corpus_parser.apply(list(doc_preprocessor))

Clearing existing...


  0%|          | 0/595 [00:00<?, ?it/s]

Running UDF...


100%|██████████| 595/595 [06:35<00:00,  1.51it/s]


In [11]:
# This will show how many documents didn't have any tagged
# (or that were otherwise not included in tagging but included
# here in parsing -- which should be rare)
tagger.get_stats()

{'n_tags': 60269,
 'n_docs': 595,
 'n_docs_found': 556,
 'pct_docs_found': 93.4453781512605}

In [4]:
from snorkel.models import Document, Sentence

print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 595
Sentences: 72547


In [5]:
docs = session.query(Document).all()

In [10]:
all_ids = [doc.name for doc in docs]
tagged_ids = list(np.intersect1d(all_ids, tags['id']))
dev_ids = list(set(annotated_ids))
train_ids = [i for i in tagged_ids if i not in dev_ids]
len(train_ids), len(dev_ids)

(454, 109)

In [11]:
train_sents, dev_sents = set(), set()
for i, doc in enumerate(docs):
    if doc.name not in tagged_ids:
        continue
    for s in doc.sentences:
        if doc.name in train_ids:
            train_sents.add(s)
        elif doc.name in dev_ids:
            dev_sents.add(s)
        else:
            raise Exception('ID <{0}> not found in any id set'.format(doc.name))

In [16]:
# for sent in session.query(Sentence).all():
#     if 'cytokine' in sent.entity_types and 'cell_type' in sent.entity_types:
#         print(set(sent.entity_types))

In [12]:
from snorkel.candidates import PretaggedCandidateExtractor
classes = get_relation_classes()

# The entity_types passed here appear to need to be exact matches on the strings
# provided above if "_" is included (must be some kind of camel case conversion oversight)
#candidate_extractor = PretaggedCandidateExtractor(InducingCytokine, ['cytokine', 'cell_type'])
candidate_extractors = [
    PretaggedCandidateExtractor(classes.inducing_cytokine_class, classes.inducing_cytokine_types),
    PretaggedCandidateExtractor(classes.secreted_cytokine_class, classes.secreted_cytokine_types),
    PretaggedCandidateExtractor(classes.inducing_transcription_factor_class, classes.inducing_transcription_factor_types),
]

In [14]:
for k, sents in enumerate([train_sents, dev_sents]):
    for extractor in candidate_extractors:
        relation_class = extractor.udf_init_kwargs['candidate_class']
        extractor.apply(sents, split=k, clear=False)
        print('Number of candidates generated for split {}, relation type {}: {}'.format(
            k, relation_class.__name__,
            session.query(relation_class).filter(relation_class.split == k).count()
        ))

  0%|          | 144/69941 [00:00<00:51, 1349.38it/s]

Running UDF...


100%|██████████| 69941/69941 [00:32<00:00, 2127.07it/s]
  0%|          | 259/69941 [00:00<00:30, 2258.16it/s]

Number of candidates generated for split 0, relation type InducingCytokine: 11817
Running UDF...


100%|██████████| 69941/69941 [00:27<00:00, 2551.44it/s]
  0%|          | 259/69941 [00:00<00:29, 2351.08it/s]

Number of candidates generated for split 0, relation type SecretedCytokine: 11817
Running UDF...


100%|██████████| 69941/69941 [00:22<00:00, 3062.05it/s]
  6%|▌         | 131/2320 [00:00<00:01, 1288.68it/s]

Number of candidates generated for split 0, relation type InducingTranscriptionFactor: 8462
Running UDF...


100%|██████████| 2320/2320 [00:01<00:00, 1718.82it/s]
  6%|▌         | 144/2320 [00:00<00:01, 1407.87it/s]

Number of candidates generated for split 1, relation type InducingCytokine: 662
Running UDF...


100%|██████████| 2320/2320 [00:01<00:00, 1963.41it/s]
 12%|█▏        | 285/2320 [00:00<00:00, 2830.72it/s]

Number of candidates generated for split 1, relation type SecretedCytokine: 662
Running UDF...


100%|██████████| 2320/2320 [00:00<00:00, 2431.39it/s]

Number of candidates generated for split 1, relation type InducingTranscriptionFactor: 455





## Load Relations

In [15]:
relations = pd.read_csv(osp.join(corpus_dir, 'relations.csv'))
relations.head()

Unnamed: 0,e1_end_chr,e1_start_chr,e1_text,e1_typ,e2_end_chr,e2_start_chr,e2_text,e2_typ,id,rel_typ
0,24,19,Gfi-1,TRANSCRIPTION_FACTOR,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
1,44,39,TGF-β,CYTOKINE,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Induction
2,44,39,TGF-β,CYTOKINE,125,90,CD103+ inducible regulatory T cells,IMMUNE_CELL_TYPE,PMC2646571,Induction
3,24,19,Gfi-1,TRANSCRIPTION_FACTOR,125,90,CD103+ inducible regulatory T cells,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
4,371,366,Gfi-1,TRANSCRIPTION_FACTOR,436,433,Th2,IMMUNE_CELL_TYPE,PMC2646571,Differentiation


In [17]:
relations['rel_typ'].value_counts()

Induction          116
Secretion          113
Differentiation    101
Name: rel_typ, dtype: int64

In [16]:
relations.groupby(['e1_typ', 'e2_typ']).size()

e1_typ                e2_typ          
CYTOKINE              IMMUNE_CELL_TYPE    116
IMMUNE_CELL_TYPE      CYTOKINE            113
TRANSCRIPTION_FACTOR  IMMUNE_CELL_TYPE    101
dtype: int64

In [21]:
# person1	person2	label
# 36c3703b-bd5b-4888-be46-2f45bcb37f8e::span:95:106	36c3703b-bd5b-4888-be46-2f45bcb37f8e::span:0:10	1
# e16a971f-23ce-42e4-81df-b2386126f8b3::span:126:134	e16a971f-23ce-42e4-81df-b2386126f8b3::span:140:157	-1

In [44]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

def get_stable_id(r):
    return '{}::span:{}:{}'.format(r['id'], r['start_chr'], r['end_chr']-1)

def load_external_labels(session, relations, candidate_class, annotator_name='gold'):
    for i, r in relations.iterrows():    

        # We check if the label already exists, in case this cell was already executed
        e1_id = get_stable_id(r.filter(regex='^e1_|^id$').rename(lambda v: v.replace('e1_', '')))
        e2_id = get_stable_id(r.filter(regex='^e2_|^id$').rename(lambda v: v.replace('e2_', '')))
        
        context_stable_ids = "~~".join([e1_id, e2_id])
        query = session.query(StableLabel)\
            .filter(StableLabel.context_stable_ids == context_stable_ids)\
            .filter(StableLabel.annotator_name == annotator_name)
        if query.count() == 0:
            session.add(StableLabel(
                context_stable_ids=context_stable_ids,
                annotator_name=annotator_name,
                value=1
            ))
            
    session.commit()
    reload_annotator_labels(
        session, candidate_class, annotator_name, split=1, 
        filter_label_split=False, create_missing_cands=False)
    
for extractor in candidate_extractors:
    relation_class = extractor.udf_init_kwargs['candidate_class']
    label = RELATION_CLASSES[relation_class.__name__]['label']
    df = relations[relations['rel_typ'] == label]
    print('Found {} relations for type {}'.format(len(df), label))
    load_external_labels(session, df, relation_class)

Found 116 relations for type Induction
AnnotatorLabels created: 99
Found 113 relations for type Secretion
AnnotatorLabels created: 99
Found 101 relations for type Differentiation
AnnotatorLabels created: 59


In [43]:
session.execute('DELETE FROM stable_label;')
session.execute('DELETE FROM gold_label;')
session.execute('DELETE FROM gold_label_key;')
from snorkel.models import GoldLabel, GoldLabelKey, StableLabel
session.commit()
session.query(StableLabel).count(), session.query(GoldLabel).count(), session.query(GoldLabelKey).count()

(0, 0, 0)

In [45]:
# stbl_ids = np.unique([
#     v
#     for r in session.query(StableLabel).all() 
#     for v in r.context_stable_ids.split('~~')
# ])
# cand_ids = np.unique([
#     cand.get_stable_id()
#     for r in session.query(InducingCytokine).all()
#     for cand in r
# ])
# len(np.intersect1d(stbl_ids, cand_ids)), len(stbl_ids), len(cand_ids)

In [46]:
from snorkel.annotations import load_gold_labels

L_gold_dev = load_gold_labels(session, annotator_name='gold', split=1)

In [57]:
type(L_gold_dev), L_gold_dev.shape

(snorkel.annotations.csr_LabelMatrix, (1779, 1))

In [55]:
pd.Series([type(L_gold_dev.get_candidate(session, i)).__name__ for i in range(L_gold_dev.shape[0])]).value_counts()

SecretedCytokine               662
InducingCytokine               662
InducingTranscriptionFactor    455
dtype: int64

In [56]:
#L_gold_dev.shape, session.query(InducingCytokine).filter(InducingCytokine.split == 1).count()