In [42]:
from pathlib import Path
from pprint import pprint

import torch
import torch.nn as nn
import datasets
from datasets import concatenate_datasets
from tqdm import tqdm
from transformers import BertForSequenceClassification
from transformers import BertForMaskedLM, BertTokenizer
import numpy as np

from common.data_utils import get_dataset
from model.tokenizer import PhraseTokenizer
from model.attacker import Attacker
from model.substitution import *

In [4]:
import torch
a = torch.tensor([[1,2,3]])
a.size()

torch.Size([1, 3])

In [6]:
a.repeat(4,1)

tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [1, 2, 3]])

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [23]:
model_name = "bert-large-uncased-whole-word-masking"
tokenizer = BertTokenizer.from_pretrained(model_name)
mlm_model = BertForMaskedLM.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Generate Adversarial Examples for the Target Sequence (multi-granuality)

In [7]:
from model.tokenizer import PhraseTokenizer

In [8]:
tgt_seq = "Frozen 2 is bad compared to its predecessor. You cannot look into the story too much."
entry = {'text': tgt_seq}

In [9]:
phrase_tok = PhraseTokenizer()
phrase_token_output = phrase_tok.tokenize(entry)

['tok2vec', 'tagger', 'parser', 'ner', 'attribute_ruler', 'lemmatizer', 'merge_phrases']


In [10]:
phrase_token_output['phrases']

['frozen',
 '2',
 'is',
 'bad',
 'compared to',
 'its',
 'predecessor',
 '.',
 'you',
 'cannot',
 'look into',
 'the',
 'story',
 'too',
 'much',
 '.']

In [12]:
phrase_token_output['phrase_offsets']

[(0, 6),
 (7, 8),
 (9, 11),
 (12, 15),
 (16, 27),
 (28, 31),
 (32, 43),
 (43, 44),
 (45, 48),
 (49, 55),
 (56, 65),
 (66, 69),
 (70, 75),
 (76, 79),
 (80, 84),
 (84, 85)]

### Map phrase index to word index

In [83]:
p_i = 0
p_s = 0
p_e = phrase_token_output['phrase_offsets'][p_i][1]
p_len = 0
phrase2word = []
new_p = True
word_count = 0
for w_s, w_e in phrase_token_output['word_offsets']:
    
    if new_p:
        p_s = word_count
        new_p = False
    
    if w_e == p_e:
        phrase2word.append([p_s, word_count+1])
        new_p = True
        p_i = min(p_i + 1, len(phrase_token_output['phrase_offsets']) - 1)
        p_e = phrase_token_output['phrase_offsets'][p_i][1]
    
    word_count += 1

In [84]:
phrase2word

[[0, 1],
 [1, 2],
 [2, 3],
 [3, 4],
 [4, 6],
 [6, 7],
 [7, 8],
 [8, 9],
 [9, 10],
 [10, 11],
 [11, 13],
 [13, 14],
 [14, 15],
 [15, 16],
 [16, 17],
 [17, 18]]

### Add 1 to phrase_len `[MASK]`' to the target sentence

In [85]:
phrase_masked_list = []
word2char = phrase_token_output['word_offsets']

mask_index_list = []
mask_count = 0
for p_s, p_e in phrase2word:
    if p_e - p_s >= 2:
        c_s = word2char[ p_s ][0]
        c_e = word2char[ p_e - 1][1]
        
        mask_len = p_e - p_s
        for l in range(1, mask_len+1):
            phrase_masked_list.append(tgt_seq[0:c_s] + ' [MASK]' * l + ' ' + tgt_seq[c_e:])
            mask_index_list.append([mask_count, mask_count + l])
            mask_count += l

In [86]:
tgt_seq

'Frozen 2 is bad compared to its predecessor. You cannot look into the story too much.'

In [87]:
phrase_masked_list

['Frozen 2 is bad  [MASK]  its predecessor. You cannot look into the story too much.',
 'Frozen 2 is bad  [MASK] [MASK]  its predecessor. You cannot look into the story too much.',
 'Frozen 2 is bad compared to its predecessor. You cannot  [MASK]  the story too much.',
 'Frozen 2 is bad compared to its predecessor. You cannot  [MASK] [MASK]  the story too much.']

### Get masked token candidates from MLM model

In [88]:
encodings = tokenizer(phrase_masked_list, truncation=True, padding=True, return_token_type_ids=False, return_tensors='pt')
inputs = encodings['input_ids'].to(device)
mask_token_index = torch.where(inputs == tokenizer.mask_token_id)[1]

In [89]:
token_logits = mlm_model(inputs, attention_mask=encodings['attention_mask'].to(device)).logits
token_logits.shape

torch.Size([4, 20, 30522])

In [90]:
mask_token_logits = torch.empty(len(mask_token_index), token_logits.shape[2])

for i,ind in enumerate(mask_index_list):
    li_s = mask_index_list[i][0]
    li_e = mask_index_list[i][1]
    ind_s = mask_token_index[li_s]
    ind_e = mask_token_index[li_e - 1] + 1
        
    mask_token_logits[li_s:li_e] = token_logits[i, ind_s:ind_e, :]

In [91]:
top_8_tokens = torch.topk(mask_token_logits, 8, dim=1).indices
top_8_tokens.shape

torch.Size([6, 8])

### Here get_substitutes check the combination of word candidates and rank them by perplexity (cross_entropy loss)

In [92]:
def get_substitutes(substitutes, tokenizer, mlm_model):
    # all substitutes  list of list of token-id (all candidates)
    c_loss = nn.CrossEntropyLoss(reduction='none')
    word_list = []

    # find all possible candidates 
    all_substitutes = []
    for i in range(substitutes.size(0)):
        if len(all_substitutes) == 0:
            lev_i = substitutes[i]
            all_substitutes = [[int(c)] for c in lev_i]
        else:
            lev_i = []
            for all_sub in all_substitutes:
                for j in substitutes[i]:
                    lev_i.append(all_sub + [int(j)])
            all_substitutes = lev_i

    # all_substitutes = all_substitutes[:24]
    all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
    all_substitutes = all_substitutes[:24].to(device)
    
    print(all_substitutes.shape) # (K ^ t, K)

    N, L = all_substitutes.size()
    word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
    ppl = c_loss(word_predictions.view(N*L, -1), all_substitutes.view(-1)) # [ N*L ] 
    ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N  
    
    _, word_list = torch.sort(ppl)
    word_list = [all_substitutes[i] for i in word_list]
    final_words = []
    for word in word_list[:24]:
        tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
        text = tokenizer.convert_tokens_to_string(tokens)
        final_words.append(text)
        
    del all_substitutes
    return final_words

In [93]:
for (i, (p_s, p_e)) in enumerate(mask_index_list):
    cur_phrase = ''
    substitutes = top_8_tokens[p_s:p_e]
    final_words = get_substitutes(substitutes, tokenizer, mlm_model)
    for s in substitutes:
        print(tokenizer.convert_ids_to_tokens(s))
    
    print(phrase_masked_list[i])
    for w in final_words[:5]:
        print(phrase_masked_list[i].replace((f' {tokenizer.mask_token}' * (p_e - p_s))[1:], w))
    print()

torch.Size([8, 1])
['as', 'like', 'for', 'after', 'enough', 'unlike', 'than', 'from']
Frozen 2 is bad  [MASK]  its predecessor. You cannot look into the story too much.
Frozen 2 is bad  as  its predecessor. You cannot look into the story too much.
Frozen 2 is bad  for  its predecessor. You cannot look into the story too much.
Frozen 2 is bad  after  its predecessor. You cannot look into the story too much.
Frozen 2 is bad  like  its predecessor. You cannot look into the story too much.
Frozen 2 is bad  than  its predecessor. You cannot look into the story too much.

torch.Size([24, 2])
['enough', 'compared', 'news', ',', 'as', 'business', 'tempered', 'looking']
['as', 'like', 'to', 'from', 'than', 'for', 'of', 'with']
Frozen 2 is bad  [MASK] [MASK]  its predecessor. You cannot look into the story too much.
Frozen 2 is bad  news of  its predecessor. You cannot look into the story too much.
Frozen 2 is bad  news for  its predecessor. You cannot look into the story too much.
Frozen 2 is b

## Importance Score

In [94]:
from transformers import AutoModelForSequenceClassification

target_model = AutoModelForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-imdb").to(device)

In [95]:
model_name = "bert-large-uncased-whole-word-masking"
tokenizer = BertTokenizer.from_pretrained(model_name)

In [96]:
# 1. retrieve logits and label from the target model
inputs = tokenizer(entry['text'], return_tensors="pt", truncation=True, max_length=512, return_token_type_ids=False)
orig_logits = target_model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device))[0].squeeze()
orig_probs  = torch.softmax(orig_logits, -1)
orig_label = torch.argmax(orig_probs)
current_prob = orig_probs.max()

### Mask each phrase with `[UNK]` token and compute the confidence change

In [97]:
# return units masked with UNK at each position in the sequence
def _get_unk_masked(units):
    len_text = len(units)
    masked_units = []
    for i in range(len_text - 1):
        masked_units.append(units[0:i] + ['[UNK]'] + units[i + 1:])
    
    # list of masked basic units
    return masked_units

'''
input units should be phrase tokens
'''
def get_important_scores(units, tgt_model, orig_prob, orig_label, orig_probs, tokenizer, batch_size=8, max_length=512):
    masked_units = _get_unk_masked(units)
    texts = [' '.join(units) for units in masked_units]  # list of text of masked units
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(tokenizer)
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=max_length, return_token_type_ids=False, return_tensors='pt')
    
    eval_data = TensorDataset(encodings['input_ids'], encodings['attention_mask'])

    # Run prediction for full data
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size)
    leave_1_probs = []
    
    tgt_model.eval() #make sure in inference stage
    
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch[0].to(device)      # input ids
            attention_mask = batch[1].to(device) # attention mask
        
            leave_1_prob_batch = tgt_model(input_ids, attention_mask=attention_mask)[0]
            leave_1_probs.append(leave_1_prob_batch)
        
    leave_1_probs = torch.cat(leave_1_probs, dim=0)  # words, num-label
    leave_1_probs = torch.softmax(leave_1_probs, -1)
    leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)
    import_scores = (orig_prob
                     - leave_1_probs[:, orig_label] # how the probability of original label decreases
                     +
                     (leave_1_probs_argmax != orig_label).float() # new label not equal to original label
                     * (leave_1_probs.max(dim=-1)[0] - torch.index_select(orig_probs, 0, leave_1_probs_argmax))
                     ).data.cpu().numpy()           # probability of changed label

    return import_scores

In [98]:
importance = get_important_scores(entry['phrases'], target_model, current_prob, orig_label, orig_probs, tokenizer, batch_size=8, max_length=512)

PreTrainedTokenizer(name_or_path='bert-large-uncased-whole-word-masking', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [99]:
sorted_indices = torch.argsort(torch.tensor(importance), dim=-1, descending=True).cpu()
sorted_units = np.array(entry['phrases'])[sorted_indices]
[(u,i) for (u,i) in zip(sorted_units, importance[sorted_indices])]

[('bad', 1.3867681),
 ('.', 0.036233902),
 ('story', 0.019454658),
 ('predecessor', 0.015556216),
 ('the', 0.01033473),
 ('much', 0.007848859),
 ('its', 0.007283628),
 ('you', 0.005296588),
 ('frozen', 0.0013504624),
 ('look into', 0.00082850456),
 ('2', 0.00012338161),
 ('cannot', -0.0028142333),
 ('too', -0.007349491),
 ('compared to', -0.012966096),
 ('is', -0.016696215)]

## Semantic Constraint

### Mask the Word 'Bad'

In [102]:
tgt_seq.find('bad')

12

In [103]:
phrase_masked_list = (tgt_seq[0:12] + ' [MASK] ' + tgt_seq[15:])
phrase_masked_list

'Frozen 2 is  [MASK]  compared to its predecessor. You cannot look into the story too much.'

In [104]:
encodings = tokenizer(phrase_masked_list, truncation=True, padding=True, return_token_type_ids=False, return_tensors='pt')
inputs = encodings['input_ids'].to(device)
mask_token_index = torch.where(inputs == tokenizer.mask_token_id)[1]
token_logits = mlm_model(inputs, attention_mask=encodings['attention_mask'].to(device)).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_8_tokens = torch.topk(mask_token_logits, 8, dim=1).indices

In [105]:
print(phrase_masked_list)
for t in tokenizer.convert_ids_to_tokens(top_8_tokens[0]):
    print(phrase_masked_list.replace(f' {tokenizer.mask_token} ', t))

Frozen 2 is  [MASK]  compared to its predecessor. You cannot look into the story too much.
Frozen 2 is nothing compared to its predecessor. You cannot look into the story too much.
Frozen 2 is tame compared to its predecessor. You cannot look into the story too much.
Frozen 2 is small compared to its predecessor. You cannot look into the story too much.
Frozen 2 is simple compared to its predecessor. You cannot look into the story too much.
Frozen 2 is weak compared to its predecessor. You cannot look into the story too much.
Frozen 2 is boring compared to its predecessor. You cannot look into the story too much.
Frozen 2 is disappointing compared to its predecessor. You cannot look into the story too much.
Frozen 2 is short compared to its predecessor. You cannot look into the story too much.


### No Mask - implicit semantic check

In [107]:
phrase_masked_list = tgt_seq

encodings = tokenizer(phrase_masked_list, truncation=True, padding=True, return_token_type_ids=False, return_tensors='pt')
inputs = encodings['input_ids'].to(device)
mask_token_index = tokenizer.convert_ids_to_tokens(inputs[0]).index('bad')
token_logits = mlm_model(inputs, attention_mask=encodings['attention_mask'].to(device)).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_8_tokens = torch.topk(mask_token_logits, 8, dim=-1).indices

In [108]:
print(phrase_masked_list)
for t in tokenizer.convert_ids_to_tokens(top_8_tokens):
    print(phrase_masked_list.replace(f'bad', t))

Frozen 2 is bad compared to its predecessor. You cannot look into the story too much.
Frozen 2 is bad compared to its predecessor. You cannot look into the story too much.
Frozen 2 is good compared to its predecessor. You cannot look into the story too much.
Frozen 2 is poor compared to its predecessor. You cannot look into the story too much.
Frozen 2 is worst compared to its predecessor. You cannot look into the story too much.
Frozen 2 is horrible compared to its predecessor. You cannot look into the story too much.
Frozen 2 is like compared to its predecessor. You cannot look into the story too much.
Frozen 2 is worse compared to its predecessor. You cannot look into the story too much.
Frozen 2 is terrible compared to its predecessor. You cannot look into the story too much.


Observation:
**Not deleting the word to be masked out does enforce semantic meaning.**

### What about multi-words?

In [127]:
mask_token_index = torch.tensor([12,13])
token_logits = mlm_model(inputs, attention_mask=encodings['attention_mask'].to(device)).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_8_tokens = torch.topk(mask_token_logits, 8, dim=-1).indices

In [128]:
final_words = get_substitutes(top_8_tokens, tokenizer, mlm_model)

torch.Size([24, 2])


In [130]:
print(phrase_masked_list)
for w in final_words[:5]:
    print(phrase_masked_list.replace((f'look into'), w))
print()

Frozen 2 is bad compared to its predecessor. You cannot look into the story too much.
Frozen 2 is bad compared to its predecessor. You cannot see in the story too much.
Frozen 2 is bad compared to its predecessor. You cannot see into the story too much.
Frozen 2 is bad compared to its predecessor. You cannot see after the story too much.
Frozen 2 is bad compared to its predecessor. You cannot see about the story too much.
Frozen 2 is bad compared to its predecessor. You cannot see behind the story too much.



Observation:
**Should not do this for phrases. Since it still enforces single-word semangtic meaning**