In [1]:
import stanza
from stanza.utils.conll import CoNLL
from transformers import BertTokenizer, BertModel, pipeline
import copy
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda')

In [69]:
# text preprocessing
def text_preprocess(filename):
    f = open(filename, mode='r')
    sent_list = f.readlines()
    return sent_list

# takes already annotated sentences and turns into a doc if necessary
def parse_conll(filename):
    doc = CoNLL.conll2doc(filename)
    return doc

def get_masks(sentences):
    """
    [Sentence] -> {str : [str]}
    """
    masked = {}
    for s in sentences:
        sent_dict = s.to_dict()
        pert = []
        for i in range(len(sent_dict)):
            if 'upos' in sent_dict[i].keys() and sent_dict[i]['upos'] != 'PUNCT':
                dict_copy = copy.deepcopy(sent_dict)
                dict_copy[i]['text'] = '[MASK]'
                pert.append((i, ''.join([w['text'] + ' ' if 'misc' not in w.keys() and w['id'] != len(sent_dict) else w['text'] for w in dict_copy])))
        masked[s.text] = pert
    return masked

# gets the scores for each token in the masked sentence

def fill_masks(masked_sentences, model_version='bert_base_uncased', number_words=20):
    """
    [Sentence] -> {str : (int, str)}
    """
    mask_dict = masked_sentences
    unmasker = pipeline('fill-mask', model='bert-base-uncased', tokenizer='bert-base-uncased', device=0)
    mask_list = [mask[1] for k in mask_dict.keys() for mask in mask_dict[k]]
    filled_list = unmasker(mask_list, top_k=number_words)
    filled_list = {mask_list[i] : [(word['score'], word['token_str']) for word in filled_list[i]] for i in range(len(mask_list))}
    return filled_list

def filter_by_pos(dependency_doc, filled_masks, pos_tagger):
    mask_dict = get_masks(dependency_doc.sentences)
    filtered_dict = {}
    for s in dependency_doc.sentences():
        sentence_text = s.text
        masked_sentences = mask_dict[sentence_text]
        for i, masked_sentence in enumerate(masked_sentences):
            filtered_tokens = []
            token_id = masked_sentence[0]
            dict_copy = copy.deepcopy(s.to_dict())
            pos = s.to_dict()[i].upos
            filled_list = filled_masks[masked_sentence]
            pert_sentences = []
            # get a list of all the proposal perturbed sentences
            for w in filled_list:
                dict_copy[token_id]['text'] = w[1]
                pert_sentence = ''.join([w['text'] + ' ' if 'misc' not in w.keys() and w['id'] != (len(s.text)) else w['text'] for w in dict_copy])
                pert_sentences.append(pert_sentence)
            pert_sentences_joined = '\n\n'.join(pert_sentences)
            pos_doc = pos_tagger(pert_setences_joined)
            for j, pert_s in enumerate(pos_doc.sentences):
                if pert_s.words[token_id].upos == pos:
                    filtered_tokens.append(filled_list[j][1])
            if s.text in filtered_dict.keys():
                filtered_dict[s.text].append((i, filtered_tokens))
            else:
                filtered_dict[s.text] = [(i, filtered_tokens)]
    return filtered_dict

In [50]:
token_list = get_masks(d.sentences)

In [None]:
filled_masks = fill_masks(token_list)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
d = parse_conll('datasets/penntreebank-ewt.conllu')
a = d.sentences[1]

a = d.sentences[3]
a.to_dict()[1]['upos']

d.sentences[3].words[1].upos

In [4]:
# getting dependency parses from doc
def get_parses(sentences, pos_tagger, ned=False)
    doc_all = nlp(all_sents)
    dependency_dict = {}
    for i,sent in enumerate(doc_all.sentences):
        deplist = [(word.id, word.head, word.deprel) for word in sent.words]
        deplist = [dep for dep in deplist if dep[2] != 'root']
        dependency_dict[i] = deplist
    
    target_sents_deps_labeled = {d : dependency_dict[d] for d in dependency_dict.keys() if d % 3 == 0}
    # adds the grandparents if we want it
    if ned:
        for k in target_sents_deps_labeled.keys():
            deps = target_sents_deps_labeled[k]
            grandparents = []
            children = [d[0] for d in deps]
            heads = [d[1] for d in deps]
            for d in deps:
                head = d[1]
                if head in children:
                    grandparents.append((d[0], heads[children.index(head)], 'grand'))
            target_sents_deps_labeled[k] = [*deps, *grandparents]
    return target_sents_deps_labeled

SyntaxError: invalid syntax (2900591857.py, line 2)

In [6]:
if __name__ == "__main__":
    stanza.download('en')
    nlp = stanza.Pipeline('en', processors='tokenize,mwt,pos,lemma,depparse', use_gpu=True, pos_batch_size=3000)

Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.4.0.json: 154kB [00:00, 32.6MB/s]                    
2022-05-30 11:22:18 INFO: Downloading default packages for language: en (English)...
Downloading https://huggingface.co/stanfordnlp/stanza-en/resolve/v1.4.0/models/default.zip: 100%|██████████| 479M/479M [00:09<00:00, 49.7MB/s] 
2022-05-30 11:22:35 INFO: Finished downloading models and saved to /home/mila/j/jasper.jian/stanza_resources.
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.4.0.json: 154kB [00:00, 42.2MB/s]                    
2022-05-30 11:22:35 INFO: Loading these models for language: en (English):
| Processor | Package  |
------------------------
| tokenize  | combined |
| pos       | combined |
| lemma     | combined |
| depparse  | combined |

2022-05-30 11:22:35 INFO: Use device: gpu
2022-05-30 11:22:35 INFO: Loading: tokenize
2022-05-30 11:22:44 INFO: Loading: pos
2022-05-30 11:22:4