## Candidate Generator

Steps:

- Load all documents 
    - Note: the documents are loaded fully into memory rather than streamed, because pyarrow segfaults when it is used in the same process as snorkel (most likely because of snorkel's pytorch imports)
   

In [1]:
# * Be careful not to import anything related to Snorkel here as it will crash pyarrow on staging step *
import os
import pandas as pd
import numpy as np
import dotenv
import shutil
import glob
import tqdm
from tcre.env import *
import os.path as osp

In [3]:
# from sqlalchemy.orm import sessionmaker
# from snorkel import SnorkelSession
# session = SnorkelSession()

## Document Staging

In [2]:
# Combine the two corpus exports into one frame, tagging every row with the
# source it came from ('entrez' or 'pmcoa')
df = pd.concat([
    pd.read_parquet(osp.join(IMPORT_DATA_DIR_02, 'corpus_01.parquet')).assign(src='entrez'),
    pd.read_parquet(osp.join(IMPORT_DATA_DIR_03, 'corpus_03.parquet')).assign(src='pmcoa')
], sort=True)
# id_pmc is the key used for de-duplication below, so it must always be present
assert df['id_pmc'].notnull().all()
# Drop repeated ids within a single source; an id that appears in both sources
# is still kept once per source at this point
df = df.drop_duplicates(subset=['src', 'id_pmc'])
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 69135 entries, 0 to 48693
Data columns (total 17 columns):
abstract          67211 non-null object
arch_archive      48685 non-null object
arch_id           48685 non-null object
arch_name         48685 non-null object
arch_path         48685 non-null object
arch_venue        48685 non-null object
body              58206 non-null object
date_accepted     69135 non-null object
date_pub          69135 non-null object
date_received     69135 non-null object
id_doi            65905 non-null object
id_pmc            69135 non-null object
id_pmid           68306 non-null object
journal_ids       69135 non-null object
journal_titles    69135 non-null object
src               69135 non-null object
title             69133 non-null object
dtypes: object(17)
memory usage: 9.5+ MB


In [3]:
# Show document intersection: build an id_pmc x src indicator table
# (1 = document present in that source, 0 = absent)
cts = df.assign(ind=1).pivot(index='id_pmc', columns='src', values='ind').fillna(0).astype(int)
# Count documents for each presence/absence combination of the two sources
cts.groupby(['entrez', 'pmcoa']).size().rename('count').reset_index()

Unnamed: 0,entrez,pmcoa,count
0,0,1,43824
1,1,0,15589
2,1,1,4861


In [4]:
(cts > 0).all(axis=1).value_counts()

False    59413
True      4861
dtype: int64

In [5]:
dupe_ids = cts[(cts > 0).all(axis=1)].index.values

In [6]:
# Prefer entrez doc for duplicates (make docs unique): keep every entrez row and
# keep a row from the other source only when its id is not in dupe_ids
dff = df[(df['src'] == 'entrez') | ~df['id_pmc'].isin(dupe_ids)]
# After the filter every id_pmc must occur exactly once
assert dff['id_pmc'].is_unique
dff.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 64274 entries, 0 to 48693
Data columns (total 17 columns):
abstract          62360 non-null object
arch_archive      43824 non-null object
arch_id           43824 non-null object
arch_name         43824 non-null object
arch_path         43824 non-null object
arch_venue        43824 non-null object
body              53365 non-null object
date_accepted     64274 non-null object
date_pub          64274 non-null object
date_received     64274 non-null object
id_doi            61191 non-null object
id_pmc            64274 non-null object
id_pmid           63447 non-null object
journal_ids       64274 non-null object
journal_titles    64274 non-null object
src               64274 non-null object
title             64272 non-null object
dtypes: object(17)
memory usage: 8.8+ MB


In [7]:
def stage_docs(df, batch_size=500, staging_dir='/tmp/corpus_staging'):
    """Write ``df`` to ``staging_dir`` as a series of feather batch files.

    Any existing ``staging_dir`` is deleted first, so afterwards the directory
    only holds the batches written by this call.  Files are named
    ``B00000.feather``, ``B00001.feather``, ...

    Args:
        df: DataFrame of documents to stage.
        batch_size: Target number of rows per file.  The number of files is
            ``len(df) // batch_size`` (but never fewer than one) and rows are
            spread as evenly as possible, so a file may hold somewhat more
            than ``batch_size`` rows.
        staging_dir: Directory that is removed, recreated and filled.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Validate before touching the filesystem so a bad argument cannot wipe an
    # existing staging directory
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    if osp.exists(staging_dir):
        shutil.rmtree(staging_dir)
    os.makedirs(staging_dir)

    # Floor division yields 0 when the frame has fewer rows than batch_size and
    # np.array_split rejects 0 sections, so always create at least one batch
    n_batches = max(1, len(df) // batch_size)
    batches = np.array_split(np.arange(len(df)), n_batches)
    for i, batch in tqdm.tqdm(enumerate(batches), total=len(batches)):
        path = osp.join(staging_dir, 'B{:05d}.feather'.format(i))
        # Index is reset per batch before writing (to_feather does not accept
        # a non-default index)
        df.iloc[batch].reset_index(drop=True).to_feather(path)
stage_docs(dff)

100%|██████████| 128/128 [00:08<00:00, 15.07it/s]


### Define Jobs To Process

In [2]:
from snorkel import SnorkelSession
from snorkel.parser import CorpusParser
from snorkel.models import Document, Sentence, Candidate
from dask.distributed import Client, progress
from tcre import processing
from tcre import supervision
import dask
import logging

In [3]:
def get_jobs(staging_dir='/tmp/corpus_staging'):
    """List the staged batch files as sorted ``(batch_index, path)`` tuples.

    Only entries named like ``B00012.feather`` are treated as jobs.  Anything
    else found in the directory (hidden files, temp files, ...) is skipped
    instead of crashing the integer parse of the batch index.

    Args:
        staging_dir: Directory containing the staged feather files.

    Returns:
        List of ``(int batch index, file path)`` tuples in ascending index order.
    """
    jobs = []
    for f in os.listdir(staging_dir):
        stem, ext = osp.splitext(f)
        # Ignore anything that does not follow the 'B<index>.feather' scheme
        if ext != '.feather' or not stem.startswith('B'):
            continue
        try:
            batch_id = int(stem[1:])
        except ValueError:
            continue
        jobs.append((batch_id, osp.join(staging_dir, f)))
    return sorted(jobs)
jobs = get_jobs()
len(jobs)

128

In [4]:
# jobs = [jobs[i] for i in [  
#     0,   1,   2,   3,   4,   5,   6,   8,   9,  11,  12,  13,  14,
# ]]
# len(jobs)

### Clear Existing Contexts

In [4]:
def get_doc_ct():
    """Return the number of ``Document`` rows, using a short-lived session."""
    session = SnorkelSession()
    try:
        return session.query(Document).count()
    finally:
        # Close the session even when the query raises
        session.close()
    
def clear_documents():
    """Clear previously parsed data via ``CorpusParser.clear`` and commit.

    A short-lived session is opened for the operation and always closed,
    including when clearing or committing fails.
    """
    session = SnorkelSession()
    try:
        # The parser callable is a do-nothing placeholder; only clear() is used
        parser = CorpusParser(parser=lambda v: None)
        parser.clear(session)
        session.commit()
    finally:
        session.close()
    
n = get_doc_ct()
clear_documents()
print('Num docs before =', n, 'after =', get_doc_ct())

Num docs before = 0 after = 0


## Document Import

In [5]:
# Start a local dask cluster with 12 single-threaded worker processes.
# NOTE(review): memory_limit appears to apply per worker (the cluster summary
# below reports 3.07 TB for 12 workers) — confirm this matches the host's RAM.
client = Client(
    threads_per_worker=1, n_workers=12, processes=True, 
    memory_limit='256GB', direct_to_workers=True, 
    silence_logs=logging.INFO
)
client

distributed.scheduler - INFO - Clear task state
distributed.scheduler - INFO -   Scheduler at:     tcp://127.0.0.1:42609
distributed.scheduler - INFO - Clear task state
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:34281'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:43737'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:38343'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:34605'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:44437'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:39241'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:41679'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:34697'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:37409'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:40481'
distributed.nanny - INFO -         Start Nanny at: 'tcp://172.17.0.2:32

0,1
Client  Scheduler: tcp://127.0.0.1:42609,Cluster  Workers: 12  Cores: 12  Memory: 3.07 TB


In [6]:
def process(job):
    """Load one staged batch of documents; submitted to dask via ``client.map``.

    Args:
        job: ``(batch_id, batch_file)`` tuple identifying the staged file.
    """
    # Local import; presumably keeps the function self-contained when it is
    # shipped to a worker process
    import logging
    # basicConfig() without a level leaves the root logger at WARNING, which
    # silently drops the INFO message below
    logging.basicConfig(level=logging.INFO)
    batch_id, batch_file = job
    logging.info('Processing job %s (%s)', batch_id, batch_file)
    loader = processing.DocLoader(batch_file)
    # limit=None: no limit is passed to the loader
    loader.run(limit=None)

In [7]:
futures = client.map(process, jobs)
dask.distributed.progress(futures, notebook=True)

VBox()

In [8]:
dask.distributed.wait(futures)

distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:32809
distributed.core - INFO - Removing comms to tcp://172.17.0.2:32809
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:33585
distributed.core - INFO - Removing comms to tcp://172.17.0.2:33585
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:45321
distributed.core - INFO - Removing comms to tcp://172.17.0.2:45321
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:44893
distributed.core - INFO - Removing comms to tcp://172.17.0.2:44893
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:37581
distributed.core - INFO - Removing comms to tcp://172.17.0.2:37581
distributed.batched - INFO - Batched Comm Closed: in <closed TCP>: Stream is closed
distributed.batched - INFO - Batched Comm Closed: in <closed TCP>: Stream is closed
distributed.batched - INFO - Batched Comm Closed: in <closed TCP>: Stream is closed
distributed.scheduler - INFO - Remove worker tcp://172.17.0.2:35005
distr

KeyboardInterrupt: 

In [None]:
assert all([f.status == 'finished' for f in futures])

In [24]:
np.array([f.status == 'finished' for f in futures]).sum()

28

In [26]:
np.argwhere([f.status != 'finished' for f in futures]).squeeze()

array([  0,   1,   2,   3,   4,   5,   6,   8,   9,  11,  12,  13,  14,
        15,  16,  17,  18,  20,  21,  23,  25,  26,  27,  29,  30,  32,
        33,  35,  36,  37,  38,  39,  41,  42,  44,  45,  47,  48,  49,
        51,  52,  54,  56,  57,  59,  60,  61,  63,  64,  68,  69,  72,
        74,  75,  76,  77,  78,  81,  82,  83,  84,  85,  87,  88,  89,
        90,  91,  92,  93,  96,  97,  98,  99, 100, 101, 102, 103, 104,
       105, 106, 107, 108, 109, 110, 111, 112, 113, 115, 116, 117, 118,
       119, 120, 121, 122, 123, 124, 125, 126, 127])

## Candidate Generation

In [21]:
session = SnorkelSession()

In [None]:
session.query(Document).count()

In [42]:
sents = session.query(Sentence).all()
len(sents)

3197

In [43]:
sents = set(sents)
len(sents)

3197

In [36]:
from snorkel.candidates import PretaggedCandidateExtractor
classes = supervision.get_candidate_classes()
candidate_extractors = [
    PretaggedCandidateExtractor(c.subclass, c.entity_types)
    for c in classes.values()
]

In [46]:
def apply_extraction(sents, split, batch_size=10000):
    """Run every candidate extractor over ``sents`` in batches.

    Uses the notebook-level ``candidate_extractors`` and ``session``.  After
    each extractor finishes, the number of candidates stored for its relation
    class and ``split`` is printed.

    Args:
        sents: Collection of sentences to extract candidates from.
        split: Split id passed to the extractors and used for the count query.
        batch_size: Maximum number of sentences handed to one ``apply`` call.
    """
    sent_list = list(sents)
    if not sent_list:
        # np.array_split cannot split into 0 sections; nothing to extract anyway
        print('No sentences provided for split {}; skipping candidate extraction'.format(split))
        return

    # Batching does not depend on the extractor, so compute it once
    n_batch = int(np.ceil(len(sent_list) / batch_size))
    batches = np.array_split(sent_list, n_batch)

    for extractor in candidate_extractors:
        relation_class = extractor.udf_init_kwargs['candidate_class']
        print('Beginning candidate extraction for split {}, relation type {}, num batches {}'.format(
            split, relation_class.__name__, n_batch
        ))
        for batch in tqdm.tqdm(batches):
            # clear=False so earlier batches and relation types are preserved
            extractor.apply(batch, split=split, clear=False, progress_bar=False)
        print('Number of candidates generated for split {}, relation type {} = {}'.format(
            split, relation_class.__name__,
            session.query(relation_class).filter(relation_class.split == split).count()
        ))


apply_extraction(sents, supervision.SPLIT_INFER)

  0%|          | 0/1 [00:00<?, ?it/s]

Beginning candidate extraction for split 9, relation type InducingCytokine, num batches 1
Running UDF...


100%|██████████| 1/1 [00:04<00:00,  4.88s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Number of candidates generated for split 9, relation type InducingCytokine = 1356
Beginning candidate extraction for split 9, relation type SecretedCytokine, num batches 1
Running UDF...


100%|██████████| 1/1 [00:03<00:00,  3.28s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

Number of candidates generated for split 9, relation type SecretedCytokine = 1356
Beginning candidate extraction for split 9, relation type InducingTranscriptionFactor, num batches 1
Running UDF...


100%|██████████| 1/1 [00:01<00:00,  1.68s/it]

Number of candidates generated for split 9, relation type InducingTranscriptionFactor = 291





In [54]:
cands = session.query(Candidate.type, Candidate.split).all()
pd.DataFrame(cands).groupby(['type', 'split']).size()

type                           split
inducing_cytokine              9        1356
inducing_transcription_factor  9         291
secreted_cytokine              9        1356
dtype: int64

## Load Gold Labels

Only relevant for ```candidate_mode == 'training'```

In [None]:
# Guard: the gold-label cells that follow only apply to the training corpus.
# NOTE(review): candidate_mode is not assigned in any cell visible here —
# presumably it comes from `from tcre.env import *` or an earlier session; confirm.
if candidate_mode != 'training':
    raise ValueError('Loading manual annotations only relevant for training data')

In [15]:
# Read csv export with annotated relations to load:
relations = pd.read_csv(osp.join(corpus_dir, 'relations.csv'))
relations.head()

Unnamed: 0,e1_end_chr,e1_start_chr,e1_text,e1_typ,e2_end_chr,e2_start_chr,e2_text,e2_typ,id,rel_typ
0,24,19,Gfi-1,TRANSCRIPTION_FACTOR,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
1,44,39,TGF-β,CYTOKINE,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Induction
2,44,39,TGF-β,CYTOKINE,119,97,inducible regulatory T,IMMUNE_CELL_TYPE,PMC2646571,Induction
3,24,19,Gfi-1,TRANSCRIPTION_FACTOR,119,97,inducible regulatory T,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
4,371,366,Gfi-1,TRANSCRIPTION_FACTOR,436,433,Th2,IMMUNE_CELL_TYPE,PMC2646571,Differentiation


In [16]:
relations['rel_typ'].value_counts()

Induction          150
Secretion          131
Differentiation    119
Name: rel_typ, dtype: int64

In [17]:
relations.groupby(['e1_typ', 'e2_typ', 'rel_typ']).size()

e1_typ                e2_typ            rel_typ        
CYTOKINE              IMMUNE_CELL_TYPE  Induction          150
                                        Secretion          131
TRANSCRIPTION_FACTOR  IMMUNE_CELL_TYPE  Differentiation    119
dtype: int64

In [145]:
# # Reset annotation tables (if loading them below fails)
# session.execute('DELETE FROM stable_label;')
# session.execute('DELETE FROM gold_label;')
# session.execute('DELETE FROM gold_label_key;')
# from snorkel.models import GoldLabel, GoldLabelKey, StableLabel
# session.commit()
# session.query(StableLabel).count(), session.query(GoldLabel).count(), session.query(GoldLabelKey).count()

In [18]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

def get_stable_id(r):
    """Build the span stable-id string for one annotated entity.

    ``r`` must provide ``id``, ``start_chr`` and ``end_chr``.  The end offset
    is reduced by one in the result (presumably converting an exclusive end
    offset to an inclusive one).
    """
    doc_id = r['id']
    first_chr = r['start_chr']
    last_chr = r['end_chr'] - 1
    return '{}::span:{}:{}'.format(doc_id, first_chr, last_chr)

def reload_labels(
    session, candidate_class, annotator_name, split, 
    filter_label_split=True, create_missing_cands=False):
    """Create GoldLabel rows from the StableLabel rows of one annotator.

    For every matching StableLabel, the contexts named in its
    ``context_stable_ids`` are looked up, then the candidate of
    ``candidate_class`` built from those contexts in ``split``.  A GoldLabel
    carrying the StableLabel's value is added when none exists yet for that
    candidate and annotator key.

    Args:
        session: Database session used for all queries and writes.
        candidate_class: Candidate subclass; its ``__argnames__`` give the
            order in which contexts are mapped to candidate arguments.
        annotator_name: Selects the StableLabels and names the GoldLabelKey
            (the key is created and committed if it does not exist).
        split: Split the candidates must belong to.
        filter_label_split: If True, only StableLabels whose ``split`` equals
            ``split`` are processed.
        create_missing_cands: If True, a candidate that cannot be found is
            constructed from the resolved contexts instead of being skipped.

    Returns:
        Tuple ``(missed, labels)``: the StableLabels that could not be matched
        to contexts or a candidate, and the GoldLabels newly created here.

    Side effects:
        Commits the session and prints a created/missed summary line.
    """
    from snorkel.models import GoldLabel, GoldLabelKey, StableLabel, Context
    from future.utils import iteritems
    # Fetch the GoldLabelKey for this annotator, creating it on first use
    ak = session.query(GoldLabelKey).filter(GoldLabelKey.name == annotator_name).first()
    if ak is None:
        ak = GoldLabelKey(name=annotator_name)
        session.add(ak)
        session.commit()

    labels = []
    missed = []
    sl_query = session.query(StableLabel).filter(StableLabel.annotator_name == annotator_name)
    sl_query = sl_query.filter(StableLabel.split == split) if filter_label_split else sl_query
    for sl in sl_query.all():
        # One stable id per candidate argument, joined with '~~'
        context_stable_ids = sl.context_stable_ids.split('~~')

        # Resolve each stable id to an existing Context
        # TODO: Does not create the Contexts if they do not yet exist!
        contexts = []
        for stable_id in context_stable_ids:
            context = session.query(Context).filter(Context.stable_id == stable_id).first()
            if context:
                contexts.append(context)
        # Any unresolved context means the label cannot be attached to a candidate
        if len(contexts) < len(context_stable_ids):
            missed.append(sl)
            continue

        # Check for Candidate
        # Assemble candidate arguments: the split plus one context per argument name
        candidate_args  = {'split' : split}
        for i, arg_name in enumerate(candidate_class.__argnames__):
            candidate_args[arg_name] = contexts[i]

        # Assemble query (one equality filter per argument) and check
        candidate_query = session.query(candidate_class)
        for k, v in iteritems(candidate_args):
            candidate_query = candidate_query.filter(getattr(candidate_class, k) == v)
        candidate = candidate_query.first()

        # Optionally construct missing candidates
        # NOTE(review): the new candidate is not added to the session explicitly;
        # presumably it is persisted through the GoldLabel added below — confirm
        if candidate is None and create_missing_cands:
            candidate = candidate_class(**candidate_args)

        # If candidate is none, mark as missed and continue
        if candidate is None:
            missed.append(sl)
            continue

        # Create the GoldLabel only if this candidate has none for this key yet,
        # which keeps repeated runs from adding duplicates
        label = session.query(GoldLabel).filter(GoldLabel.key == ak).filter(GoldLabel.candidate == candidate).first()
        if label is None:
            label = GoldLabel(candidate=candidate, key=ak, value=sl.value)
            session.add(label)
            labels.append(label)

    session.commit()
    print("AnnotatorLabels created: %s, missed: %s" % (len(labels), len(missed)))
    return missed, labels
    
def load_external_labels(session, relations, candidate_class, annotator_name='gold', split=1):
    """Store annotated relations as StableLabels and create their GoldLabels.

    Args:
        session: Database session used for all queries and writes.
        relations: DataFrame of annotated relations with columns ``id``,
            ``e1_start_chr``, ``e1_end_chr``, ``e2_start_chr`` and
            ``e2_end_chr`` (other columns are ignored here).
        candidate_class: Candidate subclass the labels belong to.
        annotator_name: Annotator name stored on the StableLabels.  Use a
            different name per candidate class when classes can share the same
            pair of spans: only the first StableLabel for a given annotator +
            context_stable_ids combination is saved.
        split: Split in which candidates are looked up when creating
            GoldLabels.  Defaults to 1, the value that was previously
            hard-coded.

    Returns:
        The ``(missed, labels)`` tuple returned by ``reload_labels``.
    """
    print(annotator_name)
    for _, r in relations.iterrows():
        # Per-entity view of the row: keep 'id' plus that entity's columns and
        # strip the 'e1_'/'e2_' prefix so get_stable_id sees start_chr/end_chr
        e1_id = get_stable_id(r.filter(regex='^e1_|^id$').rename(lambda v: v.replace('e1_', '')))
        e2_id = get_stable_id(r.filter(regex='^e2_|^id$').rename(lambda v: v.replace('e2_', '')))

        context_stable_ids = "~~".join([e1_id, e2_id])
        # We check if the label already exists, in case this cell was already executed
        query = session.query(StableLabel)\
            .filter(StableLabel.context_stable_ids == context_stable_ids)\
            .filter(StableLabel.annotator_name == annotator_name)
        if query.count() == 0:
            session.add(StableLabel(
                context_stable_ids=context_stable_ids,
                annotator_name=annotator_name,
                value=1
            ))

    session.commit()
    # GoldLabels are created from the StableLabels selected by annotator_name.
    # The split is needed to find the candidates, but the StableLabels above
    # were stored without a split, hence filter_label_split=False.
    return reload_labels(
        session, candidate_class, annotator_name, split=split,
        filter_label_split=False, create_missing_cands=False)
    
# Load gold labels for each relation type; cand_summary maps the candidate
# class name to the value returned by load_external_labels
cand_summary = {}
for extractor in candidate_extractors:
    relation_class = extractor.udf_init_kwargs['candidate_class']
    # label selects the annotated rows (rel_typ); field is used as annotator name
    label = classes[relation_class.__name__].label
    field = classes[relation_class.__name__].field
    # NOTE(review): this rebinds the notebook-level name `df`, which earlier
    # held the combined corpus frame
    df = relations[relations['rel_typ'] == label]
    assert len(df) > 0, 'Found no records for relation type {}'.format(label)
    print('Found {} relations for type {}'.format(len(df), label))
    cand_summary[relation_class.__name__] = load_external_labels(session, df, relation_class, annotator_name=field)

Found 150 relations for type Induction
inducing_cytokine
AnnotatorLabels created: 133, missed: 17
Found 131 relations for type Secretion
secreted_cytokine
AnnotatorLabels created: 81, missed: 50
Found 119 relations for type Differentiation
inducing_transcription_factor
AnnotatorLabels created: 86, missed: 33


In [19]:
from snorkel.models import Context 

# Show which entity types and relations were unable to be matched with spans
# extracted and inserted into snorkel db
def summarize_missing_candidates(cand_summary):
    """Tabulate, per missed label, which entity types have no stored Context.

    ``cand_summary`` maps a key of ``classes`` to a sequence whose first item
    is the list of missed labels.  The result has one row per missed label
    with columns ``relation``, ``doc_id`` and ``missing``, where ``missing`` is
    a comma-joined list of the entity types whose stable id matched no Context
    (empty when every context exists).
    """
    rows = []
    for key, summary in cand_summary.items():
        spec = classes[key]
        relation_name = spec.name
        for stable_label in summary[0]:
            stable_ids = stable_label.context_stable_ids
            # The document id is the part before the first '::'
            doc_id = stable_ids.split('::')[0]
            absent = [
                spec.entity_types[pos]
                for pos, sid in enumerate(stable_ids.split('~~'))
                if len(session.query(Context).filter(Context.stable_id == sid).all()) == 0
            ]
            rows.append((relation_name, doc_id, ','.join(absent)))
    return pd.DataFrame(rows, columns=['relation', 'doc_id', 'missing'])
df_miss = summarize_missing_candidates(cand_summary)
df_miss.groupby(['relation', 'missing']).size()

relation                     missing                              
InducingCytokine             cytokine                                  7
                             immune_cell_type                         10
InducingTranscriptionFactor                                            1
                             immune_cell_type                          7
                             transcription_factor                     18
                             transcription_factor,immune_cell_type     7
SecretedCytokine             cytokine                                 11
                             cytokine,immune_cell_type                 3
                             immune_cell_type                         36
dtype: int64

In [20]:
# Find documents with most annotations unable to be matched and either improve tagging or change annotations
df_miss.groupby(['relation', 'doc_id']).size().sort_values(ascending=False).head(15)

relation                     doc_id    
SecretedCytokine             PMC3046151    13
                             PMC2196041    11
InducingTranscriptionFactor  PMC2587175     8
                             PMC2783637     6
                             PMC2646571     5
SecretedCytokine             PMC2193209     5
InducingCytokine             PMC3173465     4
SecretedCytokine             PMC4385920     4
InducingTranscriptionFactor  PMC3173465     4
SecretedCytokine             PMC3650071     4
InducingTranscriptionFactor  PMC3304099     3
                             PMC5591438     2
InducingCytokine             PMC4023883     2
InducingTranscriptionFactor  PMC5206501     2
InducingCytokine             PMC2196041     2
dtype: int64

In [21]:
from snorkel.annotations import load_gold_labels, load_matrix
from snorkel.models import Candidate

# Load the gold labels for split 1 per relation class; L_gold maps the class
# key to the value returned by load_gold_labels
L_gold = {}
for c in classes:
    # NOTE(review): get_cids_query is neither defined nor explicitly imported in
    # the cells visible here — presumably provided by a star import or an
    # earlier session; confirm before relying on Restart & Run All
    cids_query = get_cids_query(session, classes[c], split=1)
    L_gold[c] = load_gold_labels(session, annotator_name=classes[c].field, split=1, cids_query=cids_query)

In [22]:
for c in L_gold:
    print(c, L_gold[c].shape)

InducingCytokine (673, 1)
SecretedCytokine (673, 1)
InducingTranscriptionFactor (410, 1)


In [23]:
for c in classes.values():
    for split in [0, 1]:
        n = session.query(c.subclass).filter(c.subclass.split == split).count()
        print('Candidate counts: {} (split {}) -> {}'.format(c.name, split, n))

Candidate counts: InducingCytokine (split 0) -> 11735
Candidate counts: InducingCytokine (split 1) -> 673
Candidate counts: SecretedCytokine (split 0) -> 11735
Candidate counts: SecretedCytokine (split 1) -> 673
Candidate counts: InducingTranscriptionFactor (split 0) -> 6696
Candidate counts: InducingTranscriptionFactor (split 1) -> 410


In [24]:
# Make sure that inducing/secreted cytokine labels are mutually exclusive
L_df = pd.DataFrame(np.hstack((
    L_gold[classes.inducing_cytokine.name].toarray(), 
    L_gold[classes.secreted_cytokine.name].toarray()
)))
L_df.groupby([0, 1]).size()

0  1
0  0    459
   1     81
1  0    133
dtype: int64

In [25]:
pd.Series([
    type(L_gold[classes.inducing_cytokine.name].get_candidate(session, i)).__name__ 
    for i in range(L_gold[classes.inducing_cytokine.name].shape[0])
]).value_counts()

InducingCytokine    673
dtype: int64