## Syntax Augmentation with BERT

### Requirements
* pytorch (1.7.0)
* transformers (3.1.0)

The first initialization of the class takes some time because it downloads the EstBERT model.

In [1]:
import torch # pytorch-1.7.0
from transformers import AutoModelWithLMHead, AutoTokenizer # transformers-3.1.0
import random
from copy import deepcopy


class BertSyntaxAugmenter():
    def __init__(self, model_name:str='tartuNLP/EstBERT_512'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelWithLMHead.from_pretrained(model_name, return_dict=True)
        
        
    def predict_10_mask(self, sentence:list)->(list, list, int):
        # chooses random word, replaces it with 10 different words predicted by BERT 
        seq = deepcopy(sentence)
        mask1 = random.choice([i for i in range(len(sentence))])
        seq[mask1] = '[MASK]'
        sequence = ' '.join(seq)
        input = self.tokenizer.encode(sequence, return_tensors="pt")
        mask_token_index = torch.where(input == self.tokenizer.mask_token_id)[1]

        token_logits = self.model(input).logits
        mask_token_logits = token_logits[0, mask_token_index, :]
        top_10_tokens = torch.topk(mask_token_logits, 10, dim=1).indices[0].tolist()
        
        tokens, texts = [], []
        for token in top_10_tokens:
            tokens.append(self.tokenizer.decode([token]))
            texts.append(sequence.replace(self.tokenizer.mask_token, self.tokenizer.decode([token])))

        return tokens, texts, mask1
    
    
    def predict_x_many_to_i(self,sentence:list, x:int, i:int)->(list, list):
        # predicts x new words to replace word in i-th place
        seq = deepcopy(sentence)
        seq[i] = '[MASK]'
        sequence = ' '.join(seq)
        input = self.tokenizer.encode(sequence, return_tensors="pt")
        mask_token_index = torch.where(input == self.tokenizer.mask_token_id)[1]

        token_logits = self.model(input).logits
        mask_token_logits = token_logits[0, mask_token_index, :]
        top_x_tokens = torch.topk(mask_token_logits,x, dim=1).indices[0].tolist()
        tokens, texts = [], []
        for token in top_x_tokens:
            tokens.append(self.tokenizer.decode([token]))
            texts.append(sequence.replace(self.tokenizer.mask_token, self.tokenizer.decode([token])))

        return tokens, texts
    
    def predict_mask_constraint(self, sentence:list, poses:list, suitable_pos:set, x:int)->(list, list, int):
        # filters words based on their POS, predicts x words to suitable index
        seq = deepcopy(sentence)
        # use the suitable_pos argument (not a global variable) to pick the candidate indices
        mask1 = random.choice([i for i in range(len(sentence)) if poses[i] in suitable_pos])
        seq[mask1] = '[MASK]'
        sequence = ' '.join(seq)
        input = self.tokenizer.encode(sequence, return_tensors="pt")
        mask_token_index = torch.where(input == self.tokenizer.mask_token_id)[1]

        token_logits = self.model(input).logits
        mask_token_logits = token_logits[0, mask_token_index, :]
        top_x_tokens = torch.topk(mask_token_logits, x, dim=1).indices[0].tolist()
        tokens, texts = [], []
        for token in top_x_tokens:
            tokens.append(self.tokenizer.decode([token]))
            texts.append(sequence.replace(self.tokenizer.mask_token, self.tokenizer.decode([token])))

        return tokens, texts, mask1
    

        

Initializing the BERT model and tokenizer:

In [2]:
augmenter = BertSyntaxAugmenter()

Some weights of BertForMaskedLM were not initialized from the model checkpoint at tartuNLP/EstBERT_512 and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


We can predict a replacement word in a sentence in several ways, and BertSyntaxAugmenter provides a method for each. If none of them does what you need, you can add your own method by following the pattern of the existing ones.

First, the method predict_10_mask predicts 10 words to replace a randomly chosen word in the sentence. The sentence must be passed as a list of words. The method returns a list of the predicted words, a list of the new sentences, and the index of the replaced word in the sentence.
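The masking step that predict_10_mask performs before it queries the model can be sketched in isolation. This is a minimal illustration without BERT; `mask_random_word` and its `rng` parameter are hypothetical names and not part of the class.

```python
import random


def mask_random_word(sentence, mask_token='[MASK]', rng=random):
    """Replace one randomly chosen word with the mask token.

    Hypothetical helper for illustration only; returns the masked text
    and the index of the replaced word.
    """
    index = rng.randrange(len(sentence))  # pick a position uniformly at random
    masked = list(sentence)               # copy so the input list is not modified
    masked[index] = mask_token
    return ' '.join(masked), index


sentence = ['Milliseks', 'kujuneb', 'Riigikassa', 'ja', 'Ühispanga', 'vahekord', '?']
# a seeded generator makes the choice reproducible
masked_text, index = mask_random_word(sentence, rng=random.Random(0))
```

Passing a seeded `random.Random` instance is optional; it only makes the chosen index repeatable across runs.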

Example:

In [3]:
sentence = ['Milliseks', 'kujuneb', 'Riigikassa', 'ja', 'Ühispanga', 'vahekord', '?']
augmenter.predict_10_mask(sentence=sentence)

(['kujuneb',
  'muutub',
  'jääb',
  'kujunes',
  'on',
  'saab',
  'sõltub',
  'muutuks',
  'kujunevad',
  'areneb'],
 ['Milliseks kujuneb Riigikassa ja Ühispanga vahekord ?',
  'Milliseks muutub Riigikassa ja Ühispanga vahekord ?',
  'Milliseks jääb Riigikassa ja Ühispanga vahekord ?',
  'Milliseks kujunes Riigikassa ja Ühispanga vahekord ?',
  'Milliseks on Riigikassa ja Ühispanga vahekord ?',
  'Milliseks saab Riigikassa ja Ühispanga vahekord ?',
  'Milliseks sõltub Riigikassa ja Ühispanga vahekord ?',
  'Milliseks muutuks Riigikassa ja Ühispanga vahekord ?',
  'Milliseks kujunevad Riigikassa ja Ühispanga vahekord ?',
  'Milliseks areneb Riigikassa ja Ühispanga vahekord ?'],
 1)
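In the output above, the first prediction is the original word kujuneb, so the first new sentence is identical to the input. If such duplicates are unwanted, they can be filtered out afterwards. The following is a minimal sketch that works on the values returned by predict_10_mask; `drop_original` is a hypothetical helper, not part of the class, and it uses a plain string comparison, which assumes the decoded token is a whole word.

```python
def drop_original(sentence, tokens, texts, index):
    """Remove predictions that equal the word originally at `index`.

    Hypothetical post-processing helper; `tokens`, `texts` and `index`
    are assumed to be the values returned by predict_10_mask.
    """
    original = sentence[index]
    # keep only the (token, text) pairs whose token differs from the original word
    kept = [(tok, txt) for tok, txt in zip(tokens, texts) if tok != original]
    new_tokens = [tok for tok, _ in kept]
    new_texts = [txt for _, txt in kept]
    return new_tokens, new_texts


sentence = ['Milliseks', 'kujuneb', 'Riigikassa', 'ja', 'Ühispanga', 'vahekord', '?']
tokens = ['kujuneb', 'muutub']
texts = ['Milliseks kujuneb Riigikassa ja Ühispanga vahekord ?',
         'Milliseks muutub Riigikassa ja Ühispanga vahekord ?']
new_tokens, new_texts = drop_original(sentence, tokens, texts, index=1)
print(new_tokens)  # ['muutub']
```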

Second, the method predict_x_many_to_i lets us choose how many words are predicted (x) and at which position (i). The output is the same, except that no index is returned.

Example: 

In [4]:
augmenter.predict_x_many_to_i(sentence=sentence, x=2, i=1)

(['kujuneb', 'muutub'],
 ['Milliseks kujuneb Riigikassa ja Ühispanga vahekord ?',
  'Milliseks muutub Riigikassa ja Ühispanga vahekord ?'])
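How the returned sentences relate to the predicted words can be sketched without the model: each prediction is substituted at position i. `substitute_at` is a hypothetical helper used only for illustration; the class itself replaces the mask token in the joined string instead.

```python
def substitute_at(sentence, i, replacements):
    """Build one new sentence per replacement word, substituted at index i."""
    texts = []
    for word in replacements:
        new_sentence = list(sentence)  # copy so the input stays unchanged
        new_sentence[i] = word
        texts.append(' '.join(new_sentence))
    return texts


sentence = ['Milliseks', 'kujuneb', 'Riigikassa', 'ja', 'Ühispanga', 'vahekord', '?']
texts = substitute_at(sentence, i=1, replacements=['kujuneb', 'muutub'])
```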

Third, we can restrict the candidate positions by POS tag with the method predict_mask_constraint.

Example (only predicting at positions where the POS tag is S):

In [5]:
from estnltk import Text
text = Text(' '.join(sentence))
text.analyse('all')
poses = [pos[0] for pos in text.morph_analysis.partofspeech] # collect the first POS tag of each word into a list
pos = {'S'}

In [6]:
poses

['P', 'V', 'S', 'J', 'H', 'S', 'Z']

In [7]:
augmenter.predict_mask_constraint(sentence=sentence, poses=poses, suitable_pos=pos, x=3)

(['mõju', 'suhe', 'vahel'],
 ['Milliseks kujuneb Riigikassa ja Ühispanga mõju ?',
  'Milliseks kujuneb Riigikassa ja Ühispanga suhe ?',
  'Milliseks kujuneb Riigikassa ja Ühispanga vahel ?'],
 5)
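The index selection behind predict_mask_constraint can be sketched on its own: only positions whose POS tag is in the allowed set are candidates for masking. `candidate_indices` is a hypothetical name; the tags are the ones printed above.

```python
def candidate_indices(poses, suitable_pos):
    """Return the indices whose POS tag is in the allowed set."""
    return [i for i, tag in enumerate(poses) if tag in suitable_pos]


poses = ['P', 'V', 'S', 'J', 'H', 'S', 'Z']
print(candidate_indices(poses, {'S'}))  # [2, 5]
print(candidate_indices(poses, {'X'}))  # []
```

An empty result means that no word qualifies. Since random.choice raises an IndexError on an empty list, it may be worth checking for that case before masking.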