## Merge BRAT Labels with Snorkel

#### This notebook automates merging BRAT-labeled phenotypes with Snorkel-extracted candidates to create a set of gold labels.

In [None]:
# Notebook setup: autoreload, environment configuration, imports, and a Snorkel DB session.
# NOTE(review): %autoreload 1 only reloads modules registered via %aimport, and
# none are registered in this notebook; %autoreload 2 would reload all modules.
%load_ext autoreload 
%autoreload 1
import os
# Optional: override the Snorkel database via the SNORKELDB environment variable.
# os.environ['SNORKELDB'] = 'sqlite:///brat-import.db'
# NOTE(review): star import; only set_env() is visibly used from it. Consider
# importing it explicitly. set_env() presumably configures the environment
# before Snorkel is imported; confirm against set_env.py.
from set_env import *
set_env()
from snorkel import SnorkelSession
import snorkel.contrib.brat as brt
from snorkel.parsers import StanfordCoreNLPServer

from snorkel.models import Candidate, Document, candidate_subclass, GoldLabel

# Shared DB session used by every cell below.
session = SnorkelSession()

In [None]:
# CoreNLP parser handed to brat.import_project when parsing the BRAT documents.
corenlp_server = StanfordCoreNLPServer(
    version='3.6.0',
    split_newline=False,
    num_threads=1,
)

In [None]:
# BRAT helper bound to the current Snorkel session; used to import the BRAT
# project (import_project) and list its annotated phenotypes (explore).
brat = brt.Brat(session)

In [None]:
# Optional cell: populate the database with Document and Sentence rows parsed from these BRAT documents.

# from snorkel.parser import TextDocPreprocessor, CorpusParser
# import multiprocessing
# doc_preprocessor = TextDocPreprocessor("brat_data/test_labeled_docs/*.txt", encoding="utf-8")
# corpus_parser = CorpusParser()
# corpus_parser.apply(doc_preprocessor, parallelism=multiprocessing.cpu_count())

### Import the BRAT-labeled data

In [None]:
# BRAT project directory to import. Documents are parsed with corenlp_server
# and annotations are recorded under the annotator name 'brat'.
input_dir = "brat_data/try_2/"
brat.import_project(
    input_dir,
    annotations_only=False,
    annotator_name='brat',
    num_threads=1,
    parser=corenlp_server,
)

In [None]:
pheno_spans = brat.explore()
print 'Total BRAT labeled phenotypes:', len(pheno_spans)
print 'Total disjoint BRAT phenotypes:', len([pheno_span for pheno_span in pheno_spans if len(pheno_span) > 1])
print 'Example BRAT phenotype spans:'
for p in pheno_spans[:5]:
    print p

### Map BRAT labels to Snorkel candidates 

In [None]:
from snorkel.models import Document, Sentence, Span

In [None]:
# Two-argument candidate type. The merge cell joins on 'descriptor' and
# treats both argument contexts together as one phenotype.
PhenoPairComplex = candidate_subclass(
    'ComplexPhenotypes',
    ['descriptor', 'entity'],
)

In [None]:
import collections
from sqlalchemy import and_
from snorkel.models import StableLabel, GoldLabel, GoldLabelKey
import math
from progressbar import ProgressBar
pbar = ProgressBar()

jaccard_cutoff = 0.3

ak = session.query(GoldLabelKey).filter(GoldLabelKey.name == 'gold_complex').first()
print "Existing number of labeled candidates:", session.query(GoldLabel).filter(GoldLabel.key == ak).count()
if ak is None:
    ak = GoldLabelKey(name='gold_complex')
    session.add(ak)
    session.commit()
else:
    # Clear the labels (oh boy)
    session.query(GoldLabel).filter(GoldLabel.key == ak).delete()
    print "Number of labeled candidates (should be 0 now):", session.query(GoldLabel).filter(GoldLabel.key == ak).count()
print '================================================'
candidates = session.query(PhenoPairComplex).filter(or_(PhenoPairComplex.split == 3, PhenoPairComplex.split == 4, PhenoPairComplex.split == 5)).all()
print "Total Snorkel candidates:", len(candidates)

sentence_to_pheno = collections.defaultdict(list)

for pheno_span in pheno_spans:
    # pheno_span is a list of fragments, where each fragment is a temporary span
    if pheno_span:
        sent_id = pheno_span[0].sentence.id
        sentence_to_pheno[sent_id].append(pheno_span)

num = 0
num_matches = 0
num_sentences = 0
num_gold_phenos = 0
num_zero_matches = 0
avg = 0
num_skipped = 0
overlap = 0
num_skipped_gap = 0

labeled = {}
j=0
for sentence_id, gold_phenos in pbar(sentence_to_pheno.items()):
    sentence_match = False
    results = session.query(Span, PhenoPairComplex).filter(and_(Span.sentence_id == sentence_id, PhenoPairComplex.descriptor_id == Span.id)).all()
    if (len(results) == 0):
        num_zero_matches += 1
        
    for gold_pheno_fragments in gold_phenos:
        matched = False
        pheno_best_jaccard = 0.0
        num_matched_for_pheno = 0
        for span, cand_pheno in results:
            # Build word indice set for each of the gold phenotype, and the candiate phenotype
            cand_words = set()
            gold_words = set()
            # combine both entities (modifier and attribute) into one "phenotype"
            [cand_words.update(xrange(span.get_word_start(), span.get_word_end()+1)) for span in cand_pheno.get_contexts()]
             
            for gold_pheno in gold_pheno_fragments:
                gold_words.update(xrange(gold_pheno.get_word_start(), gold_pheno.get_word_end()+1))
            
            # Compute distance for each candidate word
            distances = []
            for word_idx in cand_words:
                if word_idx in gold_words:
                    distance = 0
                else:
                    frag_distances = []
                    for gold_pheno in gold_pheno_fragments:
                        frag_min = gold_pheno.get_word_start()
                        frag_max = gold_pheno.get_word_end()
                        if word_idx < frag_min:
                            frag_distances.append(frag_min - word_idx)
                        elif word_idx > frag_max:
                            frag_distances.append(word_idx - frag_max)
                    
                    word_distance = min(frag_distances)
                    distances.append(word_distance)
            
            total_distance = sum([word_distance ** 2 for word_distance in distances])
            
            
            intersect = gold_words.intersection(cand_words)
            jaccard_score = float(len(intersect)) / len(cand_words.union(gold_words))
            
            
            val = -1
                
            if jaccard_score > jaccard_cutoff:
                cont = False
                cont2 = False
                cont3 = False
#                 if jaccard_score > pheno_best_jaccard:
#                     pheno_best_jaccard = jaccard_score
#                 elif num_matched_for_pheno >= 2:
#                     cont3 = True
                num_words_between_gold = max(gold_words) - min(gold_words) - len(gold_words)
                num_words_between_cand = max(cand_words)-min(cand_words) - len(cand_words)
                
                if num_words_between_gold <= 1 and num_words_between_cand > 3 and total_distance > 20:
                    num_skipped_gap += 1
                    cont2 = True
                
                if total_distance > 100:
                    num_skipped += 1
                    cont = True
                
                if cont and cont2:
                    overlap +=1
                
                if not (cont or cont2 or cont3):
                    matched = True
                    sentence_match = True
                    avg += jaccard_score
                    num += 1
                    num_matched_for_pheno += 1
                    val = 1
            
            label = session.query(GoldLabel).filter(GoldLabel.key == ak).filter(GoldLabel.candidate == cand_pheno).first()
            if label is None:
                session.add(GoldLabel(candidate=cand_pheno, key=ak, value=val))
            elif label.value == -1 and val == 1:
                label.value = val
            
        num_gold_phenos += 1
        if matched:
            num_matches += 1
    if sentence_match:
        num_sentences += 1
session.commit()

num_sentences_tagged = len(sentence_to_pheno.keys())
print "Total number of BRAT tagged sentences", num_sentences_tagged
print "Total number of BRAT phenotypes", num_gold_phenos

print "overlap", overlap
print "num skipped", num_skipped
print "num_skipped_gap", num_skipped_gap
print "avg", avg / num
print "num matched", num_matches
print "num candidates labeled +1", num
recall = float(num_matches) / num_gold_phenos
print "sentences", float(num_sentences) / num_sentences_tagged
print "num sents missed", float(num_zero_matches) / num_sentences_tagged
print "recall", recall

In [None]:
# each entity/relation type is assigned to a different split

print len(candidates)
for i,c in enumerate(candidates):
    #label = session.query(GoldLabel).filter(GoldLabel.key == ak).filter(GoldLabel.candidate == c).first()
    #if label is not None and label.value == 1: print c, label
    print type(c).type, c
    if i > 5:
        break
print
    
sents = list(set([pheno_span[0].sentence.id for pheno_span in pheno_spans if pheno_span]))
candidates = [c for c in candidates if c[0].sentence.id in sents]

### Split the BRAT-labeled candidates into dev and test sets
#### Train: split 3; dev: split 4; test: split 5. BRAT-labeled documents are divided evenly between dev and test, and re-running the cell reassigns them.

In [None]:
# Sentence ids containing at least one BRAT-labeled phenotype.
sents = set()
for pheno_span in pheno_spans:
    if pheno_span:
        sents.add(pheno_span[0].sentence.id)

# Candidates from splits 3-5 that fall in one of those sentences.
c = []
for p in session.query(PhenoPairComplex).filter(or_(PhenoPairComplex.split == 3, PhenoPairComplex.split == 4, PhenoPairComplex.split == 5)).all():
    if p and p[0].sentence.id in sents:
        c.append(p)

In [None]:
import numpy as np

doc_ids = list(set([p[0].sentence.document.id for p in session.query(PhenoPairComplex).filter(or_(PhenoPairComplex.split == 3, PhenoPairComplex.split == 4, PhenoPairComplex.split == 5)).all() if p[0].sentence.id in sents]))
print 'Total BRAT labeled documents:', len(doc_ids)
print '=============================='

#randomly select the documents to be added to the dev and tests sets
#half of the BRAT labeled documents will be the dev set and the other half will be the test set
np.random.seed(742)
split_size = len(doc_ids)/2
split_1 = np.random.choice(doc_ids, split_size, replace=False)

num_c = 0
for cand in c:
    cand.split = 4
    if cand.get_contexts()[0].sentence.document.id in split_1:
        cand.split = 5
        num_c += 1
session.commit()

print 'Total BRAT labeled phenotypes:', len(c)
print 'Total train phenotypes:', len(session.query(PhenoPairComplex).filter(PhenoPairComplex.split == 3).all())
print 'Total dev phenotypes:', len(c) - num_c
print 'Total test phenotypes:', num_c


## View gold labels 

In [None]:
from snorkel.viewer import SentenceNgramViewer
from sqlalchemy import or_

# Dev (split 4) and test (split 5) candidates, displayed with their
# 'gold_complex' labels.
candidates = (
    session.query(PhenoPairComplex)
    .filter(or_(PhenoPairComplex.split == 4, PhenoPairComplex.split == 5))
    .all()
)

sv = SentenceNgramViewer(
    candidates,
    session=session,
    n_per_page=6,
    height=400,
    annotator_name='gold_complex',
)

In [None]:
# Display the gold-label viewer built in the previous cell.
sv