In [1]:
import os
import os.path as osp
import pandas as pd
import numpy as np
import snorkel
import dotenv
import glob
import tqdm
from snorkel.parser import TextDocPreprocessor
from snorkel.parser import CorpusParser
from snorkel.models import Candidate, Document, Sentence
from snorkel import SnorkelSession

dotenv.load_dotenv('../env.sh')
%run ../src/supervision.py
collection_dir = osp.join(os.environ['REPO_DATA_DIR'], 'brat', 'collection_02')
session = SnorkelSession()

# Choose whether candidates will be loaded as a train/dev/test (i.e. 0/1/2) split or as a single inference split (3)
#corpus_dir = osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_00')
#candidate_mode='training'

corpus_dir = osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_01')
candidate_mode='inference'

corpus_docs_dir = osp.join(corpus_dir, 'links')
# rm ~/repos/hammer/t-cell-relation-extraction/pm_subtype_protein_relations/snorkel/snorkel.db

In [2]:
# Clear existing candidates, if need be
ct = session.query(Candidate).count()
# if ct > 0:
#     # Clear all existing candidates (don't let extractors do it)
#     # See: https://github.com/HazyResearch/snorkel/blob/master/snorkel/candidates.py#L47
#     # *This seems to always cause a database lock somehow -- perhaps it needs to be done with autocommit but
#     # for now a workaround is to do this at the beginning and the restart the kernel
#     from snorkel.models import Candidate
#     ndelete = session.query(Candidate).delete()
#     session.commit()
#     # Restart kernel
ct

170683

In [3]:
tags = pd.read_csv(osp.join(corpus_dir, 'tags.csv'))
tags = tags[tags['type'].isin(ENT_TYPES)]
print(tags['type'].value_counts())
tags.head()

CYTOKINE                484597
IMMUNE_CELL_TYPE        314554
TRANSCRIPTION_FACTOR    245369
Name: type, dtype: int64


Unnamed: 0,id,type,ent_id,ent_lbl,ent_prefid,start_chr,end_chr,start_wrd,end_wrd,text
1,PMC5704053,IMMUNE_CELL_TYPE,CT30BC86BDEF7B1410,Treg,CTB574584AD019ABB8,26,38,7,9,Regulatory T
2,PMC5704053,IMMUNE_CELL_TYPE,CT30BC86BDEF7B1410,Treg,CTB574584AD019ABB8,125,137,29,31,regulatory T
3,PMC5704053,IMMUNE_CELL_TYPE,CTB574584AD019ABB8,Treg,CTB574584AD019ABB8,145,149,33,34,Treg
4,PMC5704053,IMMUNE_CELL_TYPE,CTB574584AD019ABB8,Treg,CTB574584AD019ABB8,327,331,63,64,Treg
8,PMC5704053,CYTOKINE,CK379C94E0D2330772,IL-10,CKC5CC1A269C01EC48,468,473,87,88,IL-10


In [4]:
from string import punctuation

def offsets_to_token(left, right, offset_array, lemmas, punc=set(punctuation)):
    token_start, token_end = None, None
    for i, c in enumerate(offset_array):
        if left >= c:
            token_start = i
        if c > right and token_end is None:
            token_end = i
            break
    token_end = len(offset_array) - 1 if token_end is None else token_end
    token_end = token_end - 1 if lemmas[token_end - 1] in punc else token_end
    return range(token_start, token_end)


class EntityTagger(object):

    def __init__(self, tags):   
        self.tags = tags.set_index('id')
        self.reset_stats()

    def reset_stats(self):
        self.stats = {'docs': set(), 'found': set()}
        return self
    
    def get_stats(self):
        return dict(
            n_tags=len(self.tags), 
            n_docs=len(self.stats['docs']),
            n_docs_found=len(self.stats['found']),
            pct_docs_found=100*len(self.stats['found'])/len(self.stats['docs'])
        )
    
    def tag(self, parts):
        """Tag tokens in a single sentence"""
        # Extract doc id (e.g. PMC123932) and character offsets of sentence
        docid, _, _, sent_start, sent_end = parts['stable_id'].split(':')
        self.stats['docs'].add(docid)
        if docid not in self.tags.index:
            return parts
        self.stats['found'].add(docid)
        tags = self.tags.loc[[docid]]
        sent_start, sent_end = int(sent_start), int(sent_end)
        for r in tags.itertuples():
            tag_start, tag_end = r.start_chr, r.end_chr
            # Determine whether or not the tag is in this sentence
            if not (sent_start <= tag_start <= sent_end):
                continue
            offsets = [offset + sent_start for offset in parts['char_offsets']]
            tkn_idx_rng = offsets_to_token(tag_start, tag_end, offsets, parts['lemmas'])
            for tkn_idx in tkn_idx_rng:
                parts['entity_types'][tkn_idx] = r.type.lower()
                parts['entity_cids'][tkn_idx] = r.ent_id + ':' + r.ent_prefid
        return parts

In [5]:
def get_id_from_file(f):
    return osp.splitext(osp.basename(f))[0]

def get_dir_doc_ids(path):
    return set([get_id_from_file(f) for f in glob.glob(osp.join(path, '*.txt'))])

dev_ids = get_dir_doc_ids(collection_dir)
corpus_00_ids = get_dir_doc_ids(osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_00', 'links'))
corpus_01_ids = get_dir_doc_ids(osp.join(os.environ['DATA_DIR'], 'articles', 'corpus', 'corpus_01', 'links'))

train_ids = corpus_00_ids.difference(dev_ids)
inference_ids = corpus_01_ids.difference(corpus_00_ids)
inserted_ids = set([r[0] for r in session.query(Document.name).all()])

# Show frequency of docs that are annotated AND have tags of some kind
#pd.Series({doc_id:doc_id in tags['id'].values for doc_id in dev_ids}).value_counts()
assert len(dev_ids.intersection(train_ids)) == 0
assert len(dev_ids.intersection(inference_ids)) == 0
len(dev_ids), len(train_ids), len(inference_ids), len(inserted_ids)

(89, 487, 9727, 6603)

In [6]:
all_corpus_files = glob.glob(osp.join(corpus_docs_dir, '*.txt'))
corpus_files = [f for f in all_corpus_files if get_id_from_file(f) not in inserted_ids]
len(all_corpus_files), len(corpus_files)

(10043, 3700)

In [7]:
from snorkel.parser import TextDocPreprocessor

class DocListProcessor(TextDocPreprocessor):
    
    def __init__(self, paths, encoding="utf-8"):
        super().__init__(None, encoding=encoding)
        self.paths = paths
                    
    def _get_files(self, path):
        return self.paths
    
#doc_preprocessor = TextDocPreprocessor(corpus_docs_dir, max_docs=10)

In [None]:
batch_size = 100
tagger = EntityTagger(tags)
corpus_parser = CorpusParser(fn=tagger.tag)

# Process in batches since Document/Sentence objects are buffered into memory
# until .commit is called (which is only done at end of UDF .apply)
for batch in tqdm.tqdm(np.array_split(corpus_files, len(corpus_files)//batch_size)):
    doc_preprocessor = DocListProcessor(list(batch))
    corpus_parser.apply(list(doc_preprocessor), clear=False)

  0%|          | 0/37 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A

Running UDF...



  2%|▏         | 2/100 [00:09<07:59,  4.89s/it][A
  3%|▎         | 3/100 [00:09<05:37,  3.48s/it][A
  4%|▍         | 4/100 [00:14<06:14,  3.90s/it][A
  5%|▌         | 5/100 [00:25<09:31,  6.01s/it][A
  7%|▋         | 7/100 [00:31<07:46,  5.01s/it][A
  8%|▊         | 8/100 [00:34<06:48,  4.44s/it][A
  9%|▉         | 9/100 [00:39<07:12,  4.76s/it][A
 10%|█         | 10/100 [00:39<05:04,  3.38s/it][A
 11%|█         | 11/100 [00:45<06:00,  4.05s/it][A
 12%|█▏        | 12/100 [00:52<07:12,  4.91s/it][A
 13%|█▎        | 13/100 [00:52<05:04,  3.50s/it][A
 14%|█▍        | 14/100 [00:59<06:35,  4.60s/it][A
 15%|█▌        | 15/100 [01:05<07:08,  5.04s/it][A
 16%|█▌        | 16/100 [01:13<08:19,  5.95s/it][A
 17%|█▋        | 17/100 [01:16<07:00,  5.06s/it][A
 18%|█▊        | 18/100 [01:17<04:55,  3.60s/it][A
 19%|█▉        | 19/100 [01:24<06:12,  4.60s/it][A
 20%|██        | 20/100 [01:29<06:23,  4.79s/it][A
 21%|██        | 21/100 [01:29<04:29,  3.41s/it][A
 22%|██▏       | 2

Running UDF...



  2%|▏         | 2/100 [00:05<04:09,  2.54s/it][A
  3%|▎         | 3/100 [00:14<07:16,  4.50s/it][A
  4%|▍         | 4/100 [00:14<05:09,  3.22s/it][A
  5%|▌         | 5/100 [00:21<06:50,  4.33s/it][A
  6%|▌         | 6/100 [00:29<08:37,  5.50s/it][A
  7%|▋         | 7/100 [00:35<08:46,  5.66s/it][A
  9%|▉         | 9/100 [00:35<06:03,  3.99s/it][A
 10%|█         | 10/100 [00:40<06:07,  4.08s/it][A
 11%|█         | 11/100 [00:40<04:17,  2.90s/it][A
 12%|█▏        | 12/100 [00:47<06:16,  4.28s/it][A
 13%|█▎        | 13/100 [00:52<06:22,  4.39s/it][A
 14%|█▍        | 14/100 [00:56<06:07,  4.28s/it][A
 15%|█▌        | 15/100 [01:02<06:43,  4.75s/it][A
 16%|█▌        | 16/100 [01:09<07:35,  5.43s/it][A
 17%|█▋        | 17/100 [01:09<05:19,  3.85s/it][A
 18%|█▊        | 18/100 [01:14<05:40,  4.15s/it][A
 19%|█▉        | 19/100 [01:18<05:48,  4.31s/it][A
 20%|██        | 20/100 [01:24<06:19,  4.74s/it][A
 21%|██        | 21/100 [01:28<05:56,  4.51s/it][A
 22%|██▏       | 2

Running UDF...



  2%|▏         | 2/100 [00:00<00:11,  8.50it/s][A
  3%|▎         | 3/100 [00:06<02:59,  1.85s/it][A
  4%|▍         | 4/100 [00:10<04:03,  2.54s/it][A
  5%|▌         | 5/100 [00:14<04:56,  3.12s/it][A
  7%|▋         | 7/100 [00:14<03:25,  2.21s/it][A
  9%|▉         | 9/100 [00:21<03:45,  2.48s/it][A
 10%|█         | 10/100 [00:27<05:20,  3.56s/it][A
 11%|█         | 11/100 [00:31<05:47,  3.91s/it][A
 12%|█▏        | 12/100 [00:36<06:13,  4.24s/it][A
 13%|█▎        | 13/100 [00:41<06:12,  4.28s/it][A
 14%|█▍        | 14/100 [00:47<07:02,  4.92s/it][A
 15%|█▌        | 15/100 [00:47<04:57,  3.50s/it][A
 16%|█▌        | 16/100 [00:52<05:16,  3.76s/it][A
 17%|█▋        | 17/100 [00:57<05:46,  4.18s/it][A
 18%|█▊        | 18/100 [01:02<06:04,  4.44s/it][A
 19%|█▉        | 19/100 [01:02<04:16,  3.17s/it][A
 20%|██        | 20/100 [01:02<03:02,  2.28s/it][A
 21%|██        | 21/100 [01:03<02:11,  1.66s/it][A
 23%|██▎       | 23/100 [01:10<02:50,  2.21s/it][A
 24%|██▍       | 

Running UDF...



  3%|▎         | 3/100 [00:04<02:25,  1.50s/it][A
  4%|▍         | 4/100 [00:09<03:59,  2.49s/it][A
  5%|▌         | 5/100 [00:09<02:51,  1.80s/it][A
  6%|▌         | 6/100 [00:14<04:11,  2.68s/it][A
  8%|▊         | 8/100 [00:19<03:53,  2.53s/it][A
  9%|▉         | 9/100 [00:23<04:49,  3.18s/it][A
 11%|█         | 11/100 [00:23<03:19,  2.25s/it][A
 12%|█▏        | 12/100 [00:24<02:24,  1.64s/it][A
 14%|█▍        | 14/100 [00:32<03:31,  2.46s/it][A
 15%|█▌        | 15/100 [00:37<04:36,  3.25s/it][A
 16%|█▌        | 16/100 [00:42<05:03,  3.61s/it][A
 17%|█▋        | 17/100 [00:48<06:03,  4.37s/it][A
 19%|█▉        | 19/100 [00:52<04:50,  3.59s/it][A
 20%|██        | 20/100 [00:59<06:21,  4.77s/it][A
 21%|██        | 21/100 [01:03<06:03,  4.60s/it][A
 22%|██▏       | 22/100 [01:11<07:13,  5.56s/it][A
 23%|██▎       | 23/100 [01:19<08:13,  6.40s/it][A
 24%|██▍       | 24/100 [01:27<08:25,  6.66s/it][A
 26%|██▌       | 26/100 [01:33<06:58,  5.66s/it][A
 27%|██▋       | 

Running UDF...



  2%|▏         | 2/100 [00:05<04:45,  2.91s/it][A
  3%|▎         | 3/100 [00:11<06:16,  3.88s/it][A
  4%|▍         | 4/100 [00:18<07:14,  4.53s/it][A
  6%|▌         | 6/100 [00:20<05:33,  3.54s/it][A
  7%|▋         | 7/100 [00:24<05:47,  3.74s/it][A
  8%|▊         | 8/100 [00:29<06:13,  4.06s/it][A
  9%|▉         | 9/100 [00:37<08:09,  5.37s/it][A
 10%|█         | 10/100 [00:46<09:26,  6.30s/it][A
 11%|█         | 11/100 [00:51<08:41,  5.86s/it][A
 12%|█▏        | 12/100 [00:56<08:19,  5.67s/it][A
 13%|█▎        | 13/100 [01:02<08:11,  5.65s/it][A
 14%|█▍        | 14/100 [01:02<05:46,  4.03s/it][A
 15%|█▌        | 15/100 [01:05<05:32,  3.92s/it][A
 16%|█▌        | 16/100 [01:06<03:54,  2.79s/it][A
 17%|█▋        | 17/100 [01:11<05:00,  3.62s/it][A
 19%|█▉        | 19/100 [01:22<05:33,  4.11s/it][A
 21%|██        | 21/100 [01:22<03:50,  2.92s/it][A
 22%|██▏       | 22/100 [01:28<04:52,  3.75s/it][A
 24%|██▍       | 24/100 [01:35<04:37,  3.65s/it][A
 25%|██▌       | 2

Running UDF...



  2%|▏         | 2/100 [00:05<04:44,  2.90s/it][A
  3%|▎         | 3/100 [00:12<06:39,  4.12s/it][A
  4%|▍         | 4/100 [00:17<06:39,  4.16s/it][A
  5%|▌         | 5/100 [00:23<07:38,  4.82s/it][A

In [None]:
# from snorkel.parser import CorpusParser, Spacy

# tagger = EntityTagger(tags)
# corpus_parser = CorpusParser(fn=tagger.tag)
# corpus_parser.apply(list(doc_preprocessor), clear=False)

In [9]:
# This will show how many documents didn't have any tagged
# (or that were otherwise not included in tagging but included
# here in parsing -- which should be rare)
tagger.get_stats()

{'n_tags': 1044520,
 'n_docs': 3452,
 'n_docs_found': 3215,
 'pct_docs_found': 93.13441483198146}

In [10]:
print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

Documents: 6603
Sentences: 1192160


In [11]:
docs = session.query(Document).all()

In [18]:
all_ids = set([doc.name for doc in docs])
tagged_ids = set(np.intersect1d(list(all_ids), tags['id']))
len(all_ids), len(tagged_ids)

(6603, 5894)

In [25]:
train_sents, dev_sents, inference_sents = [], [], []
for i, doc in enumerate(docs):
    if doc.name not in tagged_ids:
        continue
    for s in doc.sentences:
        if doc.name in train_ids:
            train_sents.append(s)
        elif doc.name in dev_ids:
            dev_sents.append(s)
        elif doc.name in inference_ids:
            inference_sents.append(s)
        else:
            raise Exception('ID <{0}> not found in any id set'.format(doc.name))
train_sents, dev_sents, inference_sents = set(train_sents), set(dev_sents), set(inference_sents)

In [26]:
from snorkel.candidates import PretaggedCandidateExtractor
classes = get_candidate_classes()
candidate_extractors = [
    PretaggedCandidateExtractor(c.subclass, c.entity_types)
    for c in classes.values()
]

In [46]:
# for extractor in candidate_extractors:
#     extractor.clear(session, split=SPLIT_INFER)
# session.commit()

In [27]:
def apply_extraction(sents, split, batch_size=10000):
    for extractor in candidate_extractors:
        relation_class = extractor.udf_init_kwargs['candidate_class']
        n_batch = len(sents)//batch_size
        print('Beginning candidate extraction for split {}, relation type {}, num batches {}'.format(
            split, relation_class.__name__, n_batch
        ))
        for batch in tqdm.tqdm(np.array_split(list(sents), n_batch)):
            extractor.apply(batch, split=split, clear=False, progress_bar=False)
        print('Number of candidates generated for split {}, relation type {} = {}'.format(
            split, relation_class.__name__,
            session.query(relation_class).filter(relation_class.split == split).count()
        ))

if candidate_mode == 'training':
    for split, sents in enumerate([train_sents, dev_sents]):
        apply_extraction(sents, split)
elif candidate_mode == 'inference':
    apply_extraction(inference_sents, SPLIT_INFER)
else:
    raise ValueError('Candidate mode "{}" not valid'.format(candidate_mode))
    

Beginning candidate extraction for split 3, relation type InducingCytokine, num batches 111





  0%|          | 0/111 [00:00<?, ?it/s][A[A[A

Running UDF...





  1%|          | 1/111 [00:02<05:08,  2.81s/it][A[A[A

Running UDF...





  2%|▏         | 2/111 [00:05<05:02,  2.77s/it][A[A[A

Running UDF...





  3%|▎         | 3/111 [00:08<04:53,  2.72s/it][A[A[A

Running UDF...





  4%|▎         | 4/111 [00:10<04:45,  2.66s/it][A[A[A

Running UDF...





  5%|▍         | 5/111 [00:13<04:44,  2.68s/it][A[A[A

Running UDF...





  5%|▌         | 6/111 [00:15<04:39,  2.66s/it][A[A[A

Running UDF...





  6%|▋         | 7/111 [00:18<04:30,  2.60s/it][A[A[A

Running UDF...





  7%|▋         | 8/111 [00:20<04:22,  2.55s/it][A[A[A

Running UDF...





  8%|▊         | 9/111 [00:23<04:21,  2.57s/it][A[A[A

Running UDF...





  9%|▉         | 10/111 [00:26<04:28,  2.65s/it][A[A[A

Running UDF...





 10%|▉         | 11/111 [00:28<04:23,  2.63s/it][A[A[A

Running UDF...





 11%|█         | 12/111 [00:31<04:20,  2.63s/it][A[A[A

Running UDF...





 12%|█▏        | 13/111 [00:34<04:22,  2.68s/it][A[A[A

Running UDF...





 13%|█▎        | 14/111 [00:37<04:21,  2.70s/it][A[A[A

Running UDF...





 14%|█▎        | 15/111 [00:39<04:22,  2.73s/it][A[A[A

Running UDF...





 14%|█▍        | 16/111 [00:42<04:22,  2.76s/it][A[A[A

Running UDF...





 15%|█▌        | 17/111 [00:45<04:23,  2.81s/it][A[A[A

Running UDF...





 16%|█▌        | 18/111 [00:48<04:14,  2.73s/it][A[A[A

Running UDF...





 17%|█▋        | 19/111 [00:50<04:08,  2.70s/it][A[A[A

Running UDF...





 18%|█▊        | 20/111 [00:53<04:01,  2.66s/it][A[A[A

Running UDF...





 19%|█▉        | 21/111 [00:55<03:55,  2.61s/it][A[A[A

Running UDF...





 20%|█▉        | 22/111 [00:58<03:47,  2.56s/it][A[A[A

Running UDF...





 21%|██        | 23/111 [01:00<03:48,  2.60s/it][A[A[A

Running UDF...





 22%|██▏       | 24/111 [01:03<03:42,  2.55s/it][A[A[A

Running UDF...





 23%|██▎       | 25/111 [01:05<03:38,  2.54s/it][A[A[A

Running UDF...





 23%|██▎       | 26/111 [01:08<03:30,  2.47s/it][A[A[A

Running UDF...





 24%|██▍       | 27/111 [01:10<03:25,  2.45s/it][A[A[A

Running UDF...





 25%|██▌       | 28/111 [01:13<03:25,  2.48s/it][A[A[A

Running UDF...





 26%|██▌       | 29/111 [01:15<03:30,  2.56s/it][A[A[A

Running UDF...





 27%|██▋       | 30/111 [01:18<03:33,  2.64s/it][A[A[A

Running UDF...





 28%|██▊       | 31/111 [01:21<03:35,  2.70s/it][A[A[A

Running UDF...





 29%|██▉       | 32/111 [01:24<03:32,  2.69s/it][A[A[A

Running UDF...





 30%|██▉       | 33/111 [01:27<03:34,  2.75s/it][A[A[A

Running UDF...





 31%|███       | 34/111 [01:29<03:27,  2.69s/it][A[A[A

Running UDF...





 32%|███▏      | 35/111 [01:32<03:25,  2.70s/it][A[A[A

Running UDF...





 32%|███▏      | 36/111 [01:34<03:16,  2.61s/it][A[A[A

Running UDF...





 33%|███▎      | 37/111 [01:37<03:18,  2.68s/it][A[A[A

Running UDF...





 34%|███▍      | 38/111 [01:40<03:15,  2.68s/it][A[A[A

Running UDF...





 35%|███▌      | 39/111 [01:43<03:14,  2.70s/it][A[A[A

Running UDF...





 36%|███▌      | 40/111 [01:45<03:10,  2.69s/it][A[A[A

Running UDF...





 37%|███▋      | 41/111 [01:48<03:04,  2.63s/it][A[A[A

Running UDF...





 38%|███▊      | 42/111 [01:50<02:56,  2.56s/it][A[A[A

Running UDF...





 39%|███▊      | 43/111 [01:52<02:49,  2.49s/it][A[A[A

Running UDF...





 40%|███▉      | 44/111 [01:55<02:45,  2.47s/it][A[A[A

Running UDF...





 41%|████      | 45/111 [01:57<02:43,  2.47s/it][A[A[A

Running UDF...





 41%|████▏     | 46/111 [02:00<02:40,  2.46s/it][A[A[A

Running UDF...





 42%|████▏     | 47/111 [02:02<02:37,  2.47s/it][A[A[A

Running UDF...





 43%|████▎     | 48/111 [02:05<02:34,  2.46s/it][A[A[A

Running UDF...





 44%|████▍     | 49/111 [02:07<02:31,  2.44s/it][A[A[A

Running UDF...





 45%|████▌     | 50/111 [02:09<02:23,  2.36s/it][A[A[A

Running UDF...





 46%|████▌     | 51/111 [02:12<02:22,  2.38s/it][A[A[A

Running UDF...





 47%|████▋     | 52/111 [02:14<02:23,  2.43s/it][A[A[A

Running UDF...





 48%|████▊     | 53/111 [02:17<02:23,  2.47s/it][A[A[A

Running UDF...





 49%|████▊     | 54/111 [02:19<02:20,  2.46s/it][A[A[A

Running UDF...





 50%|████▉     | 55/111 [02:22<02:15,  2.41s/it][A[A[A

Running UDF...





 50%|█████     | 56/111 [02:24<02:11,  2.39s/it][A[A[A

Running UDF...





 51%|█████▏    | 57/111 [02:26<02:11,  2.44s/it][A[A[A

Running UDF...





 52%|█████▏    | 58/111 [02:29<02:07,  2.41s/it][A[A[A

Running UDF...





 53%|█████▎    | 59/111 [02:31<02:06,  2.43s/it][A[A[A

Running UDF...





 54%|█████▍    | 60/111 [02:34<02:03,  2.42s/it][A[A[A

Running UDF...





 55%|█████▍    | 61/111 [02:36<02:00,  2.41s/it][A[A[A

Running UDF...


OperationalError: (sqlite3.OperationalError) database is locked
[SQL: INSERT INTO context (type, stable_id) VALUES (?, ?)]
[parameters: ('span', 'PMC4666629::span:28003:28007')]
(Background on this error at: http://sqlalche.me/e/e3q8)

## Load Gold Labels

Only relevant for ```candidate_mode == 'training'```

In [15]:
# Read csv export with annotated relations to load:
relations = pd.read_csv(osp.join(corpus_dir, 'relations.csv'))
relations.head()

Unnamed: 0,e1_end_chr,e1_start_chr,e1_text,e1_typ,e2_end_chr,e2_start_chr,e2_text,e2_typ,id,rel_typ
0,24,19,Gfi-1,TRANSCRIPTION_FACTOR,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
1,44,39,TGF-β,CYTOKINE,85,81,Th17,IMMUNE_CELL_TYPE,PMC2646571,Induction
2,44,39,TGF-β,CYTOKINE,119,97,inducible regulatory T,IMMUNE_CELL_TYPE,PMC2646571,Induction
3,24,19,Gfi-1,TRANSCRIPTION_FACTOR,119,97,inducible regulatory T,IMMUNE_CELL_TYPE,PMC2646571,Differentiation
4,371,366,Gfi-1,TRANSCRIPTION_FACTOR,436,433,Th2,IMMUNE_CELL_TYPE,PMC2646571,Differentiation


In [16]:
relations['rel_typ'].value_counts()

Induction          150
Secretion          131
Differentiation    119
Name: rel_typ, dtype: int64

In [17]:
relations.groupby(['e1_typ', 'e2_typ', 'rel_typ']).size()

e1_typ                e2_typ            rel_typ        
CYTOKINE              IMMUNE_CELL_TYPE  Induction          150
                                        Secretion          131
TRANSCRIPTION_FACTOR  IMMUNE_CELL_TYPE  Differentiation    119
dtype: int64

In [145]:
# # Reset annotation tables (if loading them below fails)
# session.execute('DELETE FROM stable_label;')
# session.execute('DELETE FROM gold_label;')
# session.execute('DELETE FROM gold_label_key;')
# from snorkel.models import GoldLabel, GoldLabelKey, StableLabel
# session.commit()
# session.query(StableLabel).count(), session.query(GoldLabel).count(), session.query(GoldLabelKey).count()

In [18]:
from snorkel.models import StableLabel
from snorkel.db_helpers import reload_annotator_labels

def get_stable_id(r):
    return '{}::span:{}:{}'.format(r['id'], r['start_chr'], r['end_chr']-1)

def reload_labels(
    session, candidate_class, annotator_name, split, 
    filter_label_split=True, create_missing_cands=False):
    """Reloads stable annotator labels into the AnnotatorLabel table"""
    from snorkel.models import GoldLabel, GoldLabelKey, StableLabel, Context
    from future.utils import iteritems
    # Sets up the AnnotatorLabelKey to use
    ak = session.query(GoldLabelKey).filter(GoldLabelKey.name == annotator_name).first()
    if ak is None:
        ak = GoldLabelKey(name=annotator_name)
        session.add(ak)
        session.commit()

    labels = []
    missed = []
    sl_query = session.query(StableLabel).filter(StableLabel.annotator_name == annotator_name)
    sl_query = sl_query.filter(StableLabel.split == split) if filter_label_split else sl_query
    for sl in sl_query.all():
        context_stable_ids = sl.context_stable_ids.split('~~')

        # Check for labeled Contexts
        # TODO: Does not create the Contexts if they do not yet exist!
        contexts = []
        for stable_id in context_stable_ids:
            context = session.query(Context).filter(Context.stable_id == stable_id).first()
            if context:
                contexts.append(context)
        if len(contexts) < len(context_stable_ids):
            missed.append(sl)
            continue

        # Check for Candidate
        # Assemble candidate arguments
        candidate_args  = {'split' : split}
        for i, arg_name in enumerate(candidate_class.__argnames__):
            candidate_args[arg_name] = contexts[i]

        # Assemble query and check
        candidate_query = session.query(candidate_class)
        for k, v in iteritems(candidate_args):
            candidate_query = candidate_query.filter(getattr(candidate_class, k) == v)
        candidate = candidate_query.first()

        # Optionally construct missing candidates
        if candidate is None and create_missing_cands:
            candidate = candidate_class(**candidate_args)

        # If candidate is none, mark as missed and continue
        if candidate is None:
            missed.append(sl)
            continue

        # Check for AnnotatorLabel, otherwise create
        label = session.query(GoldLabel).filter(GoldLabel.key == ak).filter(GoldLabel.candidate == candidate).first()
        if label is None:
            label = GoldLabel(candidate=candidate, key=ak, value=sl.value)
            session.add(label)
            labels.append(label)

    session.commit()
    print("AnnotatorLabels created: %s, missed: %s" % (len(labels), len(missed)))
    return missed, labels
    
def load_external_labels(session, relations, candidate_class, annotator_name='gold'):
    print(annotator_name)
    for i, r in relations.iterrows():    

        # We check if the label already exists, in case this cell was already executed
        e1_id = get_stable_id(r.filter(regex='^e1_|^id$').rename(lambda v: v.replace('e1_', '')))
        e2_id = get_stable_id(r.filter(regex='^e2_|^id$').rename(lambda v: v.replace('e2_', '')))
        
        context_stable_ids = "~~".join([e1_id, e2_id])
        query = session.query(StableLabel)\
            .filter(StableLabel.context_stable_ids == context_stable_ids)\
            .filter(StableLabel.annotator_name == annotator_name)
        if query.count() == 0:
            session.add(StableLabel(
                context_stable_ids=context_stable_ids,
                annotator_name=annotator_name,
                value=1
            ))
            
    session.commit()
    # This function will create GoldLabel records for each StableLabel above after
    # selecting them based on annotator_name.  The annotator name should be different
    # for each candidate class if they might have identical context_stable_ids since
    # otherwise as written above only the first value for the same annotator + context_stable_ids
    # will be saved.
    # Other notes: split will be used to find candidates necessary to create GoldLabels though
    # it is not necessary for StableLabel filtering in this case (thus filter_label_split=False)
    # because the labels were not created with a split above
    #reload_annotator_labels(
    return reload_labels(
        session, candidate_class, annotator_name, split=1, 
        filter_label_split=False, create_missing_cands=False)
    
cand_summary = {}
for extractor in candidate_extractors:
    relation_class = extractor.udf_init_kwargs['candidate_class']
    label = classes[relation_class.__name__].label
    field = classes[relation_class.__name__].field
    df = relations[relations['rel_typ'] == label]
    assert len(df) > 0, 'Found no records for relation type {}'.format(label)
    print('Found {} relations for type {}'.format(len(df), label))
    cand_summary[relation_class.__name__] = load_external_labels(session, df, relation_class, annotator_name=field)

Found 150 relations for type Induction
inducing_cytokine
AnnotatorLabels created: 133, missed: 17
Found 131 relations for type Secretion
secreted_cytokine
AnnotatorLabels created: 81, missed: 50
Found 119 relations for type Differentiation
inducing_transcription_factor
AnnotatorLabels created: 86, missed: 33


In [19]:
from snorkel.models import Context 

# Show which entity types and relations were unable to be matched with spans
# extracted and inserted into snorkel db
def summarize_missing_candidates(cand_summary):
    df = []
    for c in cand_summary:
        class_name = classes[c].name
        missed = cand_summary[c][0]
        for mc in missed:
            doc_id = mc.context_stable_ids.split('::')[0]
            ids = mc.context_stable_ids.split('~~')
            typs = []
            for i, sid in enumerate(ids):
                ctx = session.query(Context).filter(Context.stable_id == sid).all()
                if len(ctx) == 0:
                    typs.append(classes[c].entity_types[i])
            df.append((class_name, doc_id, ','.join(typs)))
    return pd.DataFrame(df, columns=['relation', 'doc_id', 'missing'])
df_miss = summarize_missing_candidates(cand_summary)
df_miss.groupby(['relation', 'missing']).size()

relation                     missing                              
InducingCytokine             cytokine                                  7
                             immune_cell_type                         10
InducingTranscriptionFactor                                            1
                             immune_cell_type                          7
                             transcription_factor                     18
                             transcription_factor,immune_cell_type     7
SecretedCytokine             cytokine                                 11
                             cytokine,immune_cell_type                 3
                             immune_cell_type                         36
dtype: int64

In [20]:
# Find documents with most annotations unable to be matched and either improve tagging or change annotations
df_miss.groupby(['relation', 'doc_id']).size().sort_values(ascending=False).head(15)

relation                     doc_id    
SecretedCytokine             PMC3046151    13
                             PMC2196041    11
InducingTranscriptionFactor  PMC2587175     8
                             PMC2783637     6
                             PMC2646571     5
SecretedCytokine             PMC2193209     5
InducingCytokine             PMC3173465     4
SecretedCytokine             PMC4385920     4
InducingTranscriptionFactor  PMC3173465     4
SecretedCytokine             PMC3650071     4
InducingTranscriptionFactor  PMC3304099     3
                             PMC5591438     2
InducingCytokine             PMC4023883     2
InducingTranscriptionFactor  PMC5206501     2
InducingCytokine             PMC2196041     2
dtype: int64

In [21]:
from snorkel.annotations import load_gold_labels, load_matrix
from snorkel.models import Candidate

L_gold = {}
for c in classes:
    cids_query = get_cids_query(session, classes[c], split=1)
    L_gold[c] = load_gold_labels(session, annotator_name=classes[c].field, split=1, cids_query=cids_query)

In [22]:
for c in L_gold:
    print(c, L_gold[c].shape)

InducingCytokine (673, 1)
SecretedCytokine (673, 1)
InducingTranscriptionFactor (410, 1)


In [23]:
for c in classes.values():
    for split in [0, 1]:
        n = session.query(c.subclass).filter(c.subclass.split == split).count()
        print('Candidate counts: {} (split {}) -> {}'.format(c.name, split, n))

Candidate counts: InducingCytokine (split 0) -> 11735
Candidate counts: InducingCytokine (split 1) -> 673
Candidate counts: SecretedCytokine (split 0) -> 11735
Candidate counts: SecretedCytokine (split 1) -> 673
Candidate counts: InducingTranscriptionFactor (split 0) -> 6696
Candidate counts: InducingTranscriptionFactor (split 1) -> 410


In [24]:
# Make sure that inducing/secreted cytokine labels are mutally exclusive
L_df = pd.DataFrame(np.hstack((
    L_gold[classes.inducing_cytokine.name].toarray(), 
    L_gold[classes.secreted_cytokine.name].toarray()
)))
L_df.groupby([0, 1]).size()

0  1
0  0    459
   1     81
1  0    133
dtype: int64

In [25]:
pd.Series([
    type(L_gold[classes.inducing_cytokine.name].get_candidate(session, i)).__name__ 
    for i in range(L_gold[classes.inducing_cytokine.name].shape[0])
]).value_counts()

InducingCytokine    673
dtype: int64