In [24]:
from pathlib import Path
from pprint import pprint

import torch  
import datasets
from datasets import concatenate_datasets
from tqdm import tqdm
from transformers import (
  AutoTokenizer,
  BertForSequenceClassification,
  AutoModelForMaskedLM,
  pipeline
)

from common.data_utils import get_dataset
from model.tokenizer import PhraseTokenizer
from model.attacker import Attacker
from model.importance import *

In [2]:
import torch.nn as nn

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [5]:
from transformers import BertForMaskedLM, BertTokenizer

In [6]:
model_name = "bert-large-uncased-whole-word-masking"
tokenizer = BertTokenizer.from_pretrained(model_name)
mlm_model = BertForMaskedLM.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
tgt_seq = "What a pity! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process."
entry = {'text': tgt_seq}

In [8]:
phrase_tok = PhraseTokenizer()
phrase_token_output = phrase_tok.tokenize(entry)

['tagger', 'parser', 'ner', 'merge_noun_chunks', 'merge_entities']


In [9]:
phrase_token_output

{'text': 'what a pity! frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'words': ['what',
  'a',
  'pity',
  '!',
  'frozen',
  '2',
  'is',
  'bad',
  'compared',
  'to',
  'its',
  'predecessor',
  ',',
  'which',
  'is',
  'possibly',
  'due',
  'to',
  'its',
  'chaotic',
  'production',
  'process',
  '.'],
 'word_offsets': [(0, 4),
  (5, 6),
  (7, 11),
  (11, 12),
  (13, 19),
  (20, 21),
  (22, 24),
  (25, 28),
  (29, 37),
  (38, 40),
  (41, 44),
  (45, 56),
  (56, 57),
  (58, 63),
  (64, 66),
  (67, 75),
  (76, 79),
  (80, 82),
  (83, 86),
  (87, 94),
  (95, 105),
  (106, 113),
  (113, 114)],
 'phrases': ['what a pity',
  '!',
  'frozen 2',
  'is',
  'bad',
  'compared',
  'to',
  'its predecessor',
  ',',
  'which',
  'is',
  'possibly',
  'due',
  'to',
  'its chaotic production process',
  '.'],
 'phrase_offsets': [(0, 11),
  (11, 12),
  (13, 21),
  (22, 24),
  (25, 28),
  (29, 37),
  (38, 40),
  (41, 56),
  (56, 57),
  

In [10]:
p_i = 0
p_s = 0
p_e = phrase_token_output['phrase_offsets'][p_i][1]
p_len = 0
phrase2word = []
new_p = True
word_count = 0
for w_s, w_e in phrase_token_output['word_offsets']:
    
    if new_p:
        p_s = word_count
        new_p = False
    
    if w_e == p_e:
        phrase2word.append([p_s, word_count+1])
        new_p = True
        p_i = min(p_i + 1, len(phrase_token_output['phrase_offsets']) - 1)
        p_e = phrase_token_output['phrase_offsets'][p_i][1]
    
    word_count += 1
        

In [11]:
phrase2word

[[0, 3],
 [3, 4],
 [4, 6],
 [6, 7],
 [7, 8],
 [8, 9],
 [9, 10],
 [10, 12],
 [12, 13],
 [13, 14],
 [14, 15],
 [15, 16],
 [16, 17],
 [17, 18],
 [18, 22],
 [22, 23]]

In [12]:
phrase_masked_list = []
word2char = phrase_token_output['word_offsets']

mask_index_list = []
mask_count = 0
for p_s, p_e in phrase2word:
    if p_e - p_s >= 2:
        c_s = word2char[ p_s ][0]
        c_e = word2char[ p_e - 1][1]
        
        mask_len = p_e - p_s
        for l in range(1, mask_len+1):
            phrase_masked_list.append(tgt_seq[0:c_s] + ' [MASK]' * l + ' ' + tgt_seq[c_e:])
            mask_index_list.append([mask_count, mask_count + l])
            mask_count += l

In [13]:
phrase_masked_list

[' [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 ' [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 ' [MASK] [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity!  [MASK]  is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity!  [MASK] [MASK]  is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to  [MASK] , which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to  [MASK] [MASK] , which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to its predecessor, which is possibly due to  [MASK] .',
 'What a pity! Frozen 2 is bad compared to its predecessor

In [14]:
phrase_masked_list

[' [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 ' [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 ' [MASK] [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity!  [MASK]  is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity!  [MASK] [MASK]  is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to  [MASK] , which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to  [MASK] [MASK] , which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to its predecessor, which is possibly due to  [MASK] .',
 'What a pity! Frozen 2 is bad compared to its predecessor

In [15]:
encodings = tokenizer(phrase_masked_list, truncation=True, padding=True, return_token_type_ids=False, return_tensors='pt')
encodings['input_ids'].shape

torch.Size([11, 25])

In [16]:
inputs = encodings['input_ids'].to(device)
mask_token_index = torch.where(inputs == tokenizer.mask_token_id)[1]

In [17]:
mask_token_index

tensor([ 1,  1,  2,  1,  2,  3,  5,  5,  6, 11, 11, 12, 19, 19, 20, 19, 20, 21,
        19, 20, 21, 22], device='cuda:0')

In [18]:
token_logits = mlm_model(inputs, attention_mask=encodings['attention_mask'].to(device)).logits
token_logits.shape

torch.Size([11, 25, 30522])

In [19]:
mask_token_logits = torch.empty(len(mask_token_index), token_logits.shape[2])

for i,ind in enumerate(mask_index_list):
    li_s = mask_index_list[i][0]
    li_e = mask_index_list[i][1]
    ind_s = mask_token_index[li_s]
    ind_e = mask_token_index[li_e - 1] + 1
    #print(li_s, li_e, ind_s, ind_e)
        
    mask_token_logits[li_s:li_e] = token_logits[i, ind_s:ind_e, :]

In [20]:
top_8_tokens = torch.topk(mask_token_logits, 8, dim=1).indices
top_8_tokens.shape

torch.Size([22, 8])

In [21]:
def get_substitutes(substitutes, tokenizer, mlm_model):
    # all substitutes  list of list of token-id (all candidates)
    c_loss = nn.CrossEntropyLoss(reduction='none')
    word_list = []

    # find all possible candidates 
    all_substitutes = []
    for i in range(substitutes.size(0)):
        if len(all_substitutes) == 0:
            lev_i = substitutes[i]
            all_substitutes = [[int(c)] for c in lev_i]
        else:
            lev_i = []
            for all_sub in all_substitutes:
                for j in substitutes[i]:
                    lev_i.append(all_sub + [int(j)])
            all_substitutes = lev_i

    # all_substitutes = all_substitutes[:24]
    all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
    all_substitutes = all_substitutes[:24].to(device)
    
    print(all_substitutes.shape) # (K ^ t, K)

    N, L = all_substitutes.size()
    word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
    ppl = c_loss(word_predictions.view(N*L, -1), all_substitutes.view(-1)) # [ N*L ] 
    ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N  
    
    _, word_list = torch.sort(ppl)
    word_list = [all_substitutes[i] for i in word_list]
    final_words = []
    for word in word_list[:24]:
        tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
        text = tokenizer.convert_tokens_to_string(tokens)
        final_words.append(text)
        
    del all_substitutes
    return final_words

In [22]:
for (i, (p_s, p_e)) in enumerate(mask_index_list):
    cur_phrase = ''
    substitutes = top_8_tokens[p_s:p_e]
    final_words = get_substitutes(substitutes, tokenizer, mlm_model)
    for s in substitutes:
        print(tokenizer.convert_ids_to_tokens(s))
    
    print(phrase_masked_list[i])
    for w in final_words[:5]:
        print(phrase_masked_list[i].replace((f' {tokenizer.mask_token}' * (p_e - p_s))[1:], w))
    print()

torch.Size([8, 1])
['go', 'ah', 'yo', 'yahoo', 'freeze', 'sorry', 'oh', 'sh']
 [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 go ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 ah ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 oh ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 yo ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 sh ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.

torch.Size([24, 2])
['hey', 'never', 'yu', 'the', 'get', 'ice', 'o', 'ala']
['up', 'go', 'attack', '##cha', 'it', 'you', 'out', '##bba']
 [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 yu it ! Froz

## Importance Score

In [None]:
target_model = BertForSequenceClassification.from_pretrained('./data/imdb/saved_model/imdb_bert_base_uncased_finetuned_normal').to(device)

In [57]:
# 1. retrieve logits and label from the target model
inputs = tokenizer(entry['text'], return_tensors="pt", truncation=True, max_length=512, return_token_type_ids=False)
orig_logits = target_model(inputs['input_ids'].to(device), inputs['attention_mask'].to(device))[0].squeeze()
orig_probs  = torch.softmax(orig_logits, -1)
orig_label = torch.argmax(orig_probs)
current_prob = orig_probs.max()

In [32]:
from model.importance import get_important_scores
get_important_scores(entry, tokenizer, target_model, orig_label, orig_logits, orig_probs)

RuntimeError: The size of tensor a (2) must match the size of tensor b (15) at non-singleton dimension 0

In [40]:
tokenizer = BertTokenizer.from_pretrained(model_name)

In [53]:
# return units masked with UNK at each position in the sequence
def _get_unk_masked(units):
    len_text = len(units)
    masked_units = []
    for i in range(len_text - 1):
        masked_units.append(units[0:i] + ['[UNK]'] + units[i + 1:])
    
    # list of masked basic units
    return masked_units

'''
input units should be phrase tokens
'''
def get_important_scores(units, tgt_model, orig_prob, orig_label, orig_probs, tokenizer, batch_size=8, max_length=512):
    masked_units = _get_unk_masked(units)
    texts = [' '.join(units) for units in masked_units]  # list of text of masked units
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    print(tokenizer)
    encodings = tokenizer(texts, truncation=True, padding='max_length', max_length=max_length, return_token_type_ids=False, return_tensors='pt')
    
    eval_data = TensorDataset(encodings['input_ids'], encodings['attention_mask'])

    # Run prediction for full data
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size)
    leave_1_probs = []
    
    tgt_model.eval() #make sure in inference stage
    
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids = batch[0].to(device)      # input ids
            attention_mask = batch[1].to(device) # attention mask
        
            leave_1_prob_batch = tgt_model(input_ids, attention_mask=attention_mask)[0]
            leave_1_probs.append(leave_1_prob_batch)
        
    leave_1_probs = torch.cat(leave_1_probs, dim=0)  # words, num-label
    leave_1_probs = torch.softmax(leave_1_probs, -1)
    leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)
    import_scores = (orig_prob
                     - leave_1_probs[:, orig_label] # how the probability of original label decreases
                     +
                     (leave_1_probs_argmax != orig_label).float() # new label not equal to original label
                     * (leave_1_probs.max(dim=-1)[0] - torch.index_select(orig_probs, 0, leave_1_probs_argmax))
                     ).data.cpu().numpy()           # probability of changed label

    return import_scores

In [60]:
importance = get_important_scores(entry['phrases'], target_model, current_prob, orig_label, orig_probs, tokenizer, batch_size=8, max_length=512)

PreTrainedTokenizer(name_or_path='bert-large-uncased-whole-word-masking', vocab_size=30522, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [71]:
sorted_indices = torch.argsort(torch.tensor(importance), dim=-1, descending=True)

In [75]:
import numpy
sorted_units = np.array(units)[sorted_indices]

In [76]:
[(u,i) for (u,i) in zip(sorted_units, importance[sorted_indices])]

[('what a pity', 0.027897894),
 ('!', 0.010043323),
 ('bad', 0.007974029),
 ('frozen 2', 0.00469023),
 ('is', 0.0031422377),
 ('possibly', 0.003043294),
 ('its chaotic production process', 0.0026093125),
 ('its predecessor', 0.001532495),
 ('due', 0.0013412237),
 ('compared', 0.0013231635),
 ('to', 0.001247406),
 ('to', 0.00043827295),
 ('is', 0.00033521652),
 (',', -0.0005329251),
 ('which', -0.0006894469)]