In [107]:
import stanza
from stanza.utils.conll import CoNLL
from transformers import BertTokenizer, BertModel, pipeline
import copy
import torch

In [108]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines where CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [116]:
# text preprocessing
def text_preprocess(filename):
    """
    str -> [str]

    Read the text file at `filename` and return its lines.

    The lines are returned exactly as `readlines()` produces them, so each
    one still carries its trailing newline (if it had one in the file).
    """
    # The context manager closes the handle even if reading raises; the
    # previous version left the file open until garbage collection.
    with open(filename, mode='r') as f:
        return f.readlines()

# takes already annotated sentences and turns into a doc if necessary
def parse_conll(filename):
    """
    str -> doc

    Hand `filename` to `CoNLL.conll2doc` and return whatever it produces
    (presumably a stanza document built from the annotated file -- confirm
    against the stanza version in use).
    """
    return CoNLL.conll2doc(filename)

def get_masks(sentences):
    """
    [Sentence] -> {str : [str]}

    For every sentence, build one perturbed copy per non-punctuation token in
    which that token's text is replaced by '[MASK]'.

    Parameters
    ----------
    sentences : iterable
        Objects exposing `.text` (the original sentence string) and
        `.to_dict()` (a list of per-token dicts with at least a 'text' key).

    Returns
    -------
    dict
        Maps each sentence's `.text` to the list of its masked variants, in
        token order. Entries without a 'upos' key, or whose 'upos' is
        'PUNCT', are never masked.

    Notes
    -----
    The sentence string is rebuilt by joining token texts. A single space is
    placed after a token unless the token carries a 'misc' key (presumably
    `SpaceAfter=No` -- confirm) or it is the last entry of the sentence.
    The "last entry" test is positional; the previous id-based test referenced
    an undefined name and raised NameError.
    """
    masked = {}
    for s in sentences:
        sent_dict = s.to_dict()
        last = len(sent_dict) - 1
        # The separator following each entry does not depend on which token
        # is masked, so compute it once per sentence.
        separators = [
            '' if ('misc' in w or idx == last) else ' '
            for idx, w in enumerate(sent_dict)
        ]
        pert = []
        for i, target in enumerate(sent_dict):
            if 'upos' not in target or target['upos'] == 'PUNCT':
                continue
            # Build the string directly instead of deep-copying the whole
            # token list just to overwrite one 'text' field.
            pieces = [
                ('[MASK]' if j == i else w['text']) + separators[j]
                for j, w in enumerate(sent_dict)
            ]
            pert.append(''.join(pieces))
        masked[s.text] = pert
    return masked

def fill_masks(sentences, model_version='bert_base_uncased', number_words=20):
    """
    [Sentence] -> {str : [(float, str)]}

    Mask each non-punctuation token of each sentence (via `get_masks`) and
    ask a fill-mask model for its top `number_words` candidates for every
    masked string.

    Parameters
    ----------
    sentences : iterable
        Passed unchanged to `get_masks`.
    model_version : str
        Model id used for both the model and the tokenizer. The legacy
        spelling 'bert_base_uncased' (the default) is mapped to
        'bert-base-uncased', the checkpoint this function previously loaded
        regardless of the argument.
    number_words : int
        Forwarded to the pipeline as `top_k`.

    Returns
    -------
    dict
        Maps each masked sentence string to a list of
        `(score, token_str)` tuples in the order the pipeline returned them.
        Returns an empty dict (without loading a model) when there is
        nothing to unmask.
    """
    mask_dict = get_masks(sentences)
    # get_masks yields one *list* of masked strings per sentence. The pipeline
    # needs a flat list of strings, and those strings are the dict keys below
    # (lists are unhashable and cannot be keys).
    mask_list = [masked for variants in mask_dict.values() for masked in variants]
    if not mask_list:
        return {}

    model_name = {'bert_base_uncased': 'bert-base-uncased'}.get(model_version, model_version)
    # Use the first GPU when there is one; -1 selects the CPU.
    device_index = 0 if torch.cuda.is_available() else -1
    unmasker = pipeline('fill-mask', model=model_name, tokenizer=model_name, device=device_index)

    filled_list = unmasker(mask_list, top_k=number_words)
    # For a single input the pipeline may hand back a bare list of prediction
    # dicts instead of a list of such lists; normalise to the nested shape.
    if isinstance(filled_list[0], dict):
        filled_list = [filled_list]

    return {
        masked: [(word['score'], word['token_str']) for word in predictions]
        for masked, predictions in zip(mask_list, filled_list)
    }

In [None]:
# getting dependency parses from doc
def get_parses(sentences, pos_tagger, ned=False):
    """
    (text, pipeline, bool) -> {int : [(id, head, deprel)]}

    Run `pos_tagger` over `sentences` and collect the labelled dependency
    edges of selected sentences.

    Parameters
    ----------
    sentences :
        Whatever `pos_tagger` accepts as input (presumably raw text for a
        stanza pipeline -- confirm against the caller).
    pos_tagger : callable
        Called as `pos_tagger(sentences)`; must return an object with a
        `.sentences` sequence whose items have `.words`, each word exposing
        `.id`, `.head` and `.deprel`.
    ned : bool
        When True, also append a `(child, grandparent, 'grand')` edge for
        every word whose head itself has a (non-root) head.

    Returns
    -------
    dict
        Maps sentence index to its list of `(id, head, deprel)` tuples with
        the 'root' relation removed. Only indices divisible by 3 are kept.
    """
    # Use the arguments; the earlier body read the notebook globals `nlp` and
    # `all_sents` instead, so it only worked with leftover kernel state.
    doc_all = pos_tagger(sentences)
    dependency_dict = {}
    for i, sent in enumerate(doc_all.sentences):
        dependency_dict[i] = [
            (word.id, word.head, word.deprel)
            for word in sent.words
            if word.deprel != 'root'
        ]

    # NOTE(review): only sentences 0, 3, 6, ... are kept. Presumably the input
    # is laid out in groups of three with the target sentence first -- confirm.
    target_sents_deps_labeled = {
        d: deps for d, deps in dependency_dict.items() if d % 3 == 0
    }

    # adds the grandparents if we want it
    if ned:
        for k, deps in list(target_sents_deps_labeled.items()):
            # child id -> head id; setdefault keeps the first occurrence,
            # matching what list.index() would have found.
            head_of = {}
            for child, head, _ in deps:
                head_of.setdefault(child, head)
            grandparents = [
                (child, head_of[head], 'grand')
                for child, head, _ in deps
                if head in head_of
            ]
            target_sents_deps_labeled[k] = [*deps, *grandparents]
    return target_sents_deps_labeled

In [None]:
if __name__ == "__main__":
    # Download the 'en' stanza resources, then build the pipeline that the
    # rest of the notebook refers to as `nlp`.
    stanza.download('en')
    nlp = stanza.Pipeline(
        'en',
        processors='tokenize,mwt,pos,lemma,depparse',
        use_gpu=True,
        pos_batch_size=3000,
    )