In [68]:
import stanza
from stanza.utils.conll import CoNLL
from transformers import BertTokenizer, BertModel, pipeline
import copy
import torch
import pickle
import random

In [56]:
# text preprocessing
def text_preprocess(filename):
    f = open(filename, mode='r')
    sent_list = f.readlines()
    return sent_list

# takes already annotated sentences and turns into a doc if necessary
def parse_conll(filename):
    doc = CoNLL.conll2doc(filename)
    return doc

def get_masks(sentences):
    """
    [Sentence] -> {str : [str]}
    """
    masked = {}
    for s in sentences:
        sent_dict = s.to_dict()
        pert = []
        for i in range(len(sent_dict)):
            if 'upos' in sent_dict[i].keys() and sent_dict[i]['upos'] != 'PUNCT':
                dict_copy = copy.deepcopy(sent_dict)
                dict_copy[i]['text'] = '[MASK]'
                pert.append((i, ''.join([w['text'] + ' ' if 'misc' not in w.keys() and w['id'] != len(sent_dict) else w['text'] for w in dict_copy])))
        masked[s.text] = pert
    return masked

# gets the scores for each token in the masked sentence

def fill_masks(masked_sentences, model_version='bert_base_uncased', number_words=20):
    """
    [Sentence] -> {str : (int, str)}
    """
    mask_dict = masked_sentences
    unmasker = pipeline('fill-mask', model='bert-base-uncased', tokenizer='bert-base-uncased', device=0)
    mask_list = [mask[1] for k in mask_dict.keys() for mask in mask_dict[k]]
    filled_list = unmasker(mask_list, top_k=number_words)
    filled_list = {mask_list[i] : [(word['score'], word['token_str']) for word in filled_list[i]] for i in range(len(mask_list))}
    return filled_list

def filter_by_pos(dependency_doc, filled_masks, pos_tagger):
    mask_dict = get_masks(dependency_doc)
    filtered_dict = {}
    for s in dependency_doc:
        sentence_text = s.text
        masked_sentences = mask_dict[sentence_text]
        for masked_sentence in masked_sentences:
            i = masked_sentence[0]
            masked_sentence = masked_sentence[1]
            filtered_tokens = []
            token_id = i
            dict_copy = copy.deepcopy(s.to_dict())
            pos = s.words[i].upos
            filled_list = filled_masks[masked_sentence]
            pert_sentences = []
            # get a list of all the proposal perturbed sentences
            for w in filled_list:
                dict_copy[token_id]['text'] = w[1]
                pert_sentence = ''.join([w['text'] + ' ' if 'misc' not in w.keys() and w['id'] != (len(s.text)) else w['text'] for w in dict_copy])
                pert_sentences.append(pert_sentence)
            pert_sentences_joined = '\n\n'.join(pert_sentences)
            pos_doc = pos_tagger(pert_sentences_joined)
            for j, pert_s in enumerate(pos_doc.sentences):
                print(pert_s.text)
                print(s.text)
                if pert_s.words[token_id].upos == pos:
                    filtered_tokens.append(filled_list[j][1])
            if s.text in filtered_dict.keys():
                filtered_dict[s.text].append((i, filtered_tokens))
            else:
                filtered_dict[s.text] = [(i, filtered_tokens)]
    return filtered_dict

def get_all_words(sentences):
    all_words = {}
    for s in sentences:
        for w in s.words:
            properties =  (w.upos, w.feats)
            if properties not in all_words.keys():
                all_words[properties] = {w.text.lower()}
            else:
                all_words[properties].add(w.text.lower())
    return all_words

In [148]:
def fill_by_template(sentences, word_dict, number_sentences=3, number_of_perturbations=2):
    perturbed_sentences = {}
    perturbed_categories = ['ADJ', 'ADV', 'NOUN', 'VERB', 'ADP']
    for num_s, s in enumerate(sentences):
        if num_s % 300 == 0:
            print(num_s)
        s_dict = s.to_dict()
        pert_list = []
        #print(s.text)
        for iteration in range(number_sentences):
            list_of_positions = [i for i in range(len(s.words)) if s.words[i].upos in perturbed_categories]
            if len(list_of_positions) < number_of_perturbations:
                continue
            sampled_positions = random.choices(list_of_positions, k=number_of_perturbations)
            sampled_positions = [(i, s.words[i].upos, s.words[i].feats, s.words[i].text) for i in sampled_positions]
            dict_copy = copy.deepcopy(s.to_dict())
            for position in sampled_positions:
                features = (position[1], position[2])
                possible_replacements = tuple(word_dict[features])
                sampled_word = random.choice(possible_replacements)
                while sampled_word == position[3].lower() and len(possible_replacements) > 1:
                    sampled_word = random.choice(possible_replacements)
                dict_copy[position[0]]['text'] = sampled_word
            pert_sentence = ''.join(['' if type(w['id']) == tuple or w['upos'] == 'PUNCT' else w['text'] if i == (len(dict_copy) - 1) else w['text'] + ' ' for i, w in enumerate(dict_copy)])
            pert_list.append(pert_sentence.strip())
        perturbed_sentences[s.text] = pert_list
    return perturbed_sentences

In [180]:
def update_parses(sentences):
    updated_doc = []
    for s in sentences:
        s_dict = s.to_dict()
        updated_parse = [w for w in s_dict if type(w['id']) != tuple and w['upos'] != 'PUNCT']
        updated_doc.append(updated_parse)
    updated_doc = stanza.models.common.doc.Document(updated_doc)
    return updated_doc

In [174]:
d_updated = update_parses(d_prime)

In [179]:
d_updated.sentences[34].to_dict()

[{'id': 2,
  'text': 'Bush',
  'lemma': 'Bush',
  'upos': 'PROPN',
  'xpos': 'NNP',
  'feats': 'Number=Sing',
  'head': 3,
  'deprel': 'nsubj',
  'deps': '3:nsubj'},
 {'id': 3,
  'text': 'fails',
  'lemma': 'fail',
  'upos': 'VERB',
  'xpos': 'VBZ',
  'feats': 'Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin',
  'head': 0,
  'deprel': 'root',
  'deps': '0:root'},
 {'id': 4,
  'text': 'reporter',
  'lemma': 'reporter',
  'upos': 'NOUN',
  'xpos': 'NN',
  'feats': 'Number=Sing',
  'head': 7,
  'deprel': 'nmod:poss',
  'deps': '7:nmod:poss'},
 {'id': 5,
  'text': "'s",
  'lemma': "'s",
  'upos': 'PART',
  'xpos': 'POS',
  'head': 4,
  'deprel': 'case',
  'deps': '4:case'},
 {'id': 6,
  'text': 'pop',
  'lemma': 'pop',
  'upos': 'NOUN',
  'xpos': 'NN',
  'feats': 'Number=Sing',
  'head': 7,
  'deprel': 'compound',
  'deps': '7:compound'},
 {'id': 7,
  'text': 'quiz',
  'lemma': 'quiz',
  'upos': 'NOUN',
  'xpos': 'NN',
  'feats': 'Number=Sing',
  'head': 3,
  'deprel': 'obj',
  'deps

In [3]:
# getting dependency parses from doc
def get_parses(sentences, pos_tagger, ned=False):
    doc_all = nlp(all_sents)
    dependency_dict = {}
    for i,sent in enumerate(doc_all.sentences):
        deplist = [(word.id, word.head, word.deprel) for word in sent.words]
        deplist = [dep for dep in deplist if dep[2] != 'root']
        dependency_dict[i] = deplist
    target_sents_deps_labeled = {d : dependency_dict[d] for d in dependency_dict.keys() if d % 3 == 0}
    # adds the grandparents if we want it
    if ned:
        for k in target_sents_deps_labeled.keys():
            deps = target_sents_deps_labeled[k]
            grandparents = []
            children = [d[0] for d in deps]
            heads = [d[1] for d in deps]
            for d in deps:
                head = d[1]
                if head in children:
                    grandparents.append((d[0], heads[children.index(head)], 'grand'))
            target_sents_deps_labeled[k] = [*deps, *grandparents]
    return target_sents_deps_labeled

In [149]:
%%time
if __name__ == "__main__":
    #stanza.download('en')
    #nlp = stanza.Pipeline('en', processors='tokenize,mwt,pos,lemma,depparse', use_gpu=True, pos_batch_size=3000)
    d = parse_conll('datasets/penntreebank-ewt.conllu')
    d_prime = [s for s in d.sentences if len(s.words) <= 15 and len(s.words) >= 4]
    #masked_sentences = get_masks(d_prime)
    #filled_masks = fill_masks(masked_sentences)
    """
    with open('test_fill_replacement.pkl', 'wb') as f:
        pickle.dump(test_fill, f)
    """
    labeled_words = get_all_words(d.sentences)
    test_fill = fill_by_template(d_prime, labeled_words, number_sentences=9, number_of_perturbations=2)
    with open('test_fill_replacement.pkl', 'wb') as f:
        pickle.dump(test_fill, f)

0
300
600
900
1200
1500
1800
2100
2400
2700
3000
3300
3600
3900
4200
4500
4800
5100
CPU times: user 13.1 s, sys: 200 ms, total: 13.3 s
Wall time: 13.3 s


In [165]:
for s in d_prime:
    punct_list = []
    for w in s.words:
        if w.text == ':':
            punct_list.append(w.id)
            print(w.deprel)
            print(s.text)
            print(list(s.to_dict())[w.head - 1]['text'])
    for w in s.words:
        if w.head in punct_list:
            print(s.text)
            continue

punct
CHERNOBYL ACCIDENT: TEN YEARS ON
ACCIDENT
punct
Here are some excerpts:
Here
punct
Let's just remember a seminal Bush moment in 1999:
Let's
punct
Web posted at: 3:29 p.m. EST (2029 GMT)
p.m.
punct
you can check it out : los angeles online dating
dating
punct
WADE GOODWYN reporting:
reporting
punct
Ms. BABA GROOM (Former Campaign Worker):
Ms.
punct
Mr. ARCHIBALD: People have different ways of starting the days in any office.
Mr.
punct
(also see this thread on the RI discussion board):
see
punct
And it's guilty :
's
punct
From SNAP's statement on the Robinson conviction:
's
punct
The murder weapon, Robinson's letter opener:
weapon
punct
SPLOID.com cited you on the topic of that priest conviction: http://www.sploid.com/news/2006/05/evil_priest_gui.php
http://www.sploid.com/news/2006/05/evil_priest_gui.php
punct
The following quote is from that Coogan article:
Coogan
punct
Here's another interesting example:
Here's
punct
Not 200,000 guns - the numbers dont work:
guns
punct
Important 