# Snorkel Candidate Extraction

Note: do not launch jupyter notebook directly, use ```sh run.sh``` from snorkle repo

https://github.com/UW-Deepdive-Infrastructure/stromatolites_snorkel/

https://github.com/HazyResearch/snorkel

http://hazyresearch.github.io/snorkel/blog/weak_supervision.html

In [1]:
%load_ext autoreload
%autoreload 2

import os
os.environ['SNORKELDB']="postgres://postgres@localhost:5432/snorkel_postgres"

from snorkel import SnorkelSession
session = SnorkelSession()

In [2]:
from snorkel.models import Sentence, Document, candidate_subclass
from snorkel.candidates import Ngrams, CandidateExtractor
from snorkel.matchers import LocationMatcher, DictionaryMatch, RegexMatchSpan
from snorkel.viewer import SentenceNgramViewer

import os

sentences = session.query(Sentence).all()
len(sentences)

147287

In [3]:
print("Documents:", session.query(Document).count())
print("Sentences:", session.query(Sentence).count())

('Documents:', 254L)
('Sentences:', 147287L)


## Candidate Extraction

In [4]:
ngram_rm = Ngrams(n_max=1)
ngram_names = Ngrams(n_max=9)

In [5]:
'''
Creating a Dam schema of relation mention to extract
Creates a database table if not available 
'''
DamCandidates = candidate_subclass('DamCandidates', ['removal_term', 'dam_name'])

In [6]:
dam_rm_matcher = RegexMatchSpan(rgx='removal|remove|destroy|breach|removed|breached|removing|post-dam|demolition|demolish|demolished|razing|razed|raze')

In [7]:
'''
Input: strings related to "removal", list of current removed dam names
Dam Names: https://github.com/bserna-usgs/DRD_DeepDive/blob/master/input/removedDams20151214.csv
'''
with open('./input/data/all_dams.csv', 'r') as f:
    dam_data = [i.split(',')[0] for i in f.readlines()]  # list of dam names

dam_name_matcher = DictionaryMatch(d=dam_data, ignore_case=False, longest_match_only=True)

In [8]:
'''
Extractor
'''
ce = CandidateExtractor(DamCandidates, [ngram_rm, ngram_names], [dam_rm_matcher, dam_name_matcher], 
                       symmetric_relations=True, nested_relations=False, self_relations=False)

In [9]:
'''
filter out sentences that have mentions of a LOCATION (NER)
Note - testing this might result in noun or proper noun 
'''
def number_of_dams(sentence):
    active_sequence = False
    count = 0 
    for tag in sentence.ner_tags:
        if tag == 'LOCATION' and not active_sequence:
            active_sequence = True
            count += 1
        elif tag != 'LOCATION' and active_sequence:
            active_sequence = False
    return count

In [10]:
'''
train-test split (from tutorial)
'''
docs = session.query(Document).order_by(Document.name).all()
ld = len(docs)

train_sents = set()
dev_sents = set()
test_sents = set()
splits = (0.8, 0.9) if 'CI' in os.environ else (0.9, 0.95)

for i, doc in enumerate(docs):
    for s in doc.sentences:
        if number_of_dams(s) < 2:
            if i < splits[0] * ld:
                train_sents.add(s)
            elif i < splits[1] * ld:
                dev_sents.add(s)
            else:
                test_sents.add(s)
                
len(train_sents), len(dev_sents), len(test_sents)

(120371, 10620, 4683)

## Running Candidate Extractor

In [11]:
#docs = session.query(Document).order_by(Document.name).all()
#ce.apply(docs, split=0)
%time ce.apply(train_sents, split=0)

Clearing existing...
Running UDF...

CPU times: user 2min 25s, sys: 2.13 s, total: 2min 27s
Wall time: 2min 29s


In [12]:
'''
get candidates
'''
train_candidates = session.query(DamCandidates).filter(DamCandidates.split == 0).all()
print('Number of candidates: ' + str(len(train_candidates)))

Number of candidates: 167


## Viewer to inspect candidates

```sh
jupyter nbextension enable --py widgetsnbextension --sys-prefix
```

In [13]:
if 'CI' not in os.environ:
    sv = SentenceNgramViewer(train_candidates[:], session)
else:
    sv = None

<IPython.core.display.Javascript object>

In [20]:
!jupyter nbextension enable --py widgetsnbextension --sys-prefix

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: problems found:
        - require? [31m X[0m jupyter-js-widgets/extension


In [29]:
!jupyter nbextension enable --py widgetsnbextension

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: problems found:
        - require? [31m X[0m jupyter-js-widgets/extension


In [None]:
sv

In [15]:
sv.get_selected()

DamCandidates(Span("removal-related", sentence=51481, chars=[220,234], words=[38,38]), Span("Mill Dam", sentence=51481, chars=[26,33], words=[6,7]))

In [17]:
if 'CI' not in os.environ:
    print(unicode(sv.get_selected()))

DamCandidates(Span("Removal", sentence=37608, chars=[0,6], words=[0,0]), Span("Edwards Dam", sentence=37608, chars=[11,21], words=[2,3]))


# Labeling Functions
## Training a model to differentiate b/t true and false mentions


In [18]:
from snorkel import CandidateSet  # no candidateset

train_candidates = session.query(CandidateSet).filter(CandidateSet.name == 'Training Candidates').one()
print(len(train_candidates))
test_candidates = session.query(CandidateSet).filter(CandidateSet.name == 'Test Candidates').one()
print(len(test_candidates))

ImportError: cannot import name CandidateSet

In [19]:
from snorkel.annotations import load_gold_labels

ImportError: No module named templates

In [None]:
# creating features
from snorkel.annotations import FeatureMa

## Session commit
note: figure out another method to save/commit changes

In [30]:
session.add(test_sents)
session.add(train_sents)
session.add(dev_sents)
session.commit()

UnmappedInstanceError: Class '__builtin__.set' is not mapped

## Traversing context hierarchy
Candidates are tuples of Context-type objects, most cases Spans

In [None]:
c = train_cands[0]
c.get_contexts()

In [None]:
span = c.get_contexts()[0]
print(span)
print(span.get_parent())
print(span.get_parent().get_parent())

In [None]:
print(span.get_span())
print(span.get_attrib_tokens())
print(span.get_attrib_tokens('pos_tags'))

In [None]:
'''
repeating for dev and test data
'''
%time
for i, sents in enumerate([dev_sents, test_sents]):
    dam_extractor.apply(sents, split = i+1)
    print('Number of candidates: ', session.query(Dam).filter(Dam.split == i+1).count())