In [1]:
from pathlib import Path
from pprint import pprint

import torch  
import datasets
from datasets import concatenate_datasets
from tqdm import tqdm
from transformers import (
  AutoTokenizer,
  BertForSequenceClassification,
  AutoModelForMaskedLM,
  pipeline
)

from common.data_utils import get_dataset
from model.tokenizer import PhraseTokenizer
from model.attacker import Attacker

In [94]:
import torch.nn as nn

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [3]:
target_model = BertForSequenceClassification.from_pretrained('./data/imdb/saved_model/imdb_bert_base_uncased_finetuned_normal').to(device)

In [4]:
from transformers import BertForMaskedLM, BertTokenizer

In [5]:
model_name = "bert-large-uncased-whole-word-masking"
tokenizer = BertTokenizer.from_pretrained(model_name)
mlm_model = BertForMaskedLM.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
tgt_seq = "What a pity! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process."
entry = {'text': tgt_seq}

In [7]:
phrase_tok = PhraseTokenizer()
phrase_token_output = phrase_tok.tokenize(entry)

['tagger', 'parser', 'ner', 'merge_noun_chunks', 'merge_entities']


In [8]:
phrase_token_output

{'text': 'what a pity! frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'words': ['what',
  'a',
  'pity',
  '!',
  'frozen',
  '2',
  'is',
  'bad',
  'compared',
  'to',
  'its',
  'predecessor',
  ',',
  'which',
  'is',
  'possibly',
  'due',
  'to',
  'its',
  'chaotic',
  'production',
  'process',
  '.'],
 'word_offsets': [(0, 4),
  (5, 6),
  (7, 11),
  (11, 12),
  (13, 19),
  (20, 21),
  (22, 24),
  (25, 28),
  (29, 37),
  (38, 40),
  (41, 44),
  (45, 56),
  (56, 57),
  (58, 63),
  (64, 66),
  (67, 75),
  (76, 79),
  (80, 82),
  (83, 86),
  (87, 94),
  (95, 105),
  (106, 113),
  (113, 114)],
 'phrases': ['what a pity',
  '!',
  'frozen 2',
  'is',
  'bad',
  'compared',
  'to',
  'its predecessor',
  ',',
  'which',
  'is',
  'possibly',
  'due',
  'to',
  'its chaotic production process',
  '.'],
 'phrase_offsets': [(0, 11),
  (11, 12),
  (13, 21),
  (22, 24),
  (25, 28),
  (29, 37),
  (38, 40),
  (41, 56),
  (56, 57),
  

In [9]:
p_i = 0
p_s = 0
p_e = phrase_token_output['phrase_offsets'][p_i][1]
p_len = 0
phrase2word = []
new_p = True
word_count = 0
for w_s, w_e in phrase_token_output['word_offsets']:
    
    if new_p:
        p_s = word_count
        new_p = False
    
    if w_e == p_e:
        phrase2word.append([p_s, word_count+1])
        new_p = True
        p_i = min(p_i + 1, len(phrase_token_output['phrase_offsets']) - 1)
        p_e = phrase_token_output['phrase_offsets'][p_i][1]
    
    word_count += 1
        

In [10]:
phrase2word

[[0, 3],
 [3, 4],
 [4, 6],
 [6, 7],
 [7, 8],
 [8, 9],
 [9, 10],
 [10, 12],
 [12, 13],
 [13, 14],
 [14, 15],
 [15, 16],
 [16, 17],
 [17, 18],
 [18, 22],
 [22, 23]]

In [90]:
phrase_masked_list = []
word2char = phrase_token_output['word_offsets']

mask_index_list = []
mask_count = 0
for p_s, p_e in phrase2word:
    if p_e - p_s >= 2:
        c_s = word2char[ p_s ][0]
        c_e = word2char[ p_e - 1][1]
        
        mask_len = p_e - p_s
        phrase_masked_list.append(tgt_seq[0:c_s] + ' [MASK]' * mask_len + ' ' + tgt_seq[c_e:])
        mask_index_list.append([mask_count, mask_count + mask_len])
        mask_count += mask_len

In [43]:
phrase_masked_list

[' [MASK] [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity!  [MASK] [MASK]  is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to  [MASK] [MASK] , which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to its predecessor, which is possibly due to  [MASK] [MASK] [MASK] [MASK] .']

In [44]:
encodings = tokenizer(phrase_masked_list, truncation=True, padding=True, return_token_type_ids=False, return_tensors='pt')

In [45]:
encodings

{'input_ids': tensor([[  101,   103,   103,   103,   999,  7708,  1016,  2003,  2919,  4102,
          2000,  2049,  8646,  1010,  2029,  2003,  4298,  2349,  2000,  2049,
         19633,  2537,  2832,  1012,   102],
        [  101,  2054,  1037, 12063,   999,   103,   103,  2003,  2919,  4102,
          2000,  2049,  8646,  1010,  2029,  2003,  4298,  2349,  2000,  2049,
         19633,  2537,  2832,  1012,   102],
        [  101,  2054,  1037, 12063,   999,  7708,  1016,  2003,  2919,  4102,
          2000,   103,   103,  1010,  2029,  2003,  4298,  2349,  2000,  2049,
         19633,  2537,  2832,  1012,   102],
        [  101,  2054,  1037, 12063,   999,  7708,  1016,  2003,  2919,  4102,
          2000,  2049,  8646,  1010,  2029,  2003,  4298,  2349,  2000,   103,
           103,   103,   103,  1012,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [76]:
inputs = encodings['input_ids'].to(device)
mask_token_index = torch.where(inputs == tokenizer.mask_token_id)[1]

In [77]:
mask_token_index

tensor([ 1,  2,  3,  5,  6, 11, 12, 19, 20, 21, 22], device='cuda:0')

In [48]:
tokenizer.convert_ids_to_tokens(encodings['input_ids'][3])

['[CLS]',
 'what',
 'a',
 'pity',
 '!',
 'frozen',
 '2',
 'is',
 'bad',
 'compared',
 'to',
 'its',
 'predecessor',
 ',',
 'which',
 'is',
 'possibly',
 'due',
 'to',
 '[MASK]',
 '[MASK]',
 '[MASK]',
 '[MASK]',
 '.',
 '[SEP]']

In [78]:
token_logits = mlm_model(inputs, attention_mask=encodings['attention_mask'].to(device)).logits
token_logits.shape

torch.Size([4, 25, 30522])

In [91]:
mask_token_logits = torch.empty(len(mask_token_index), token_logits.shape[2])

for i,ind in enumerate(mask_index_list):
    li_s = mask_index_list[i][0]
    li_e = mask_index_list[i][1]
    ind_s = mask_token_index[li_s]
    ind_e = mask_token_index[li_e - 1] + 1
    #print(li_s, li_e, ind_s, ind_e)
        
    mask_token_logits[li_s:li_e] = token_logits[i, ind_s:ind_e, :]

In [92]:
mask_index_list

[[0, 3], [3, 5], [5, 7], [7, 11]]

In [87]:
mask_token_logits.shape

torch.Size([11, 30522])

In [88]:
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices
top_5_tokens.shape

In [89]:
top_5_tokens

tensor([[ 1996, 14021,  2196,  8038,  9805],
        [  999,  1011, 23644,  7507,  1010],
        [ 2080,  2050, 10930,  2017,  2213],
        [ 1996,  2023,  2049,  2008,  2117],
        [ 2143,  8297,  3185,  2201,  2208],
        [ 2049,  7708,  1996,  3025,  3256],
        [ 8646,  2434,  1015, 16372,  2034],
        [ 1996,  2049,  1037,  3768,  2019],
        [ 3768,  1997,  1043,  2058,  3532],
        [ 1997, 15909,  1998,  1999,  2208],
        [ 2015, 11247,  4180,  3314,  3194]])

In [111]:
mask_index_list

[[0, 3], [3, 5], [5, 7], [7, 11]]

In [128]:
for (i, (p_s, p_e)) in enumerate(mask_index_list):
    cur_phrase = ''
    substitutes = top_5_tokens[p_s:p_e]
    final_words = get_substitutes(substitutes, tokenizer, mlm_model)
    for s in substitutes:
        print(tokenizer.convert_ids_to_tokens(s))
    
    print(phrase_masked_list[i])
    for w in final_words[:5]:
        print(phrase_masked_list[i].replace((f' {tokenizer.mask_token}' * (p_e - p_s))[1:], w))
    print()

torch.Size([125, 3])
['the', 'sh', 'never', 'ya', 'yu']
['!', '-', '##hh', '##cha', ',']
['##o', '##a', 'yo', 'you', '##m']
 [MASK] [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 the !m ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 thehha ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 thechao ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 thecha you ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 the ,o ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.

torch.Size([25, 2])
['the', 'this', 'its', 'that', 'second']
['film', 'sequel', 'movie', 'album', 'game']
What a pity!  [MASK] [MASK]  is bad compared to its predecessor, which is poss

In [123]:
def get_substitutes(substitutes, tokenizer, mlm_model):
    # all substitutes  list of list of token-id (all candidates)
    c_loss = nn.CrossEntropyLoss(reduction='none')
    word_list = []

    # find all possible candidates 
    all_substitutes = []
    for i in range(substitutes.size(0)):
        if len(all_substitutes) == 0:
            lev_i = substitutes[i]
            all_substitutes = [[int(c)] for c in lev_i]
        else:
            lev_i = []
            for all_sub in all_substitutes:
                for j in substitutes[i]:
                    lev_i.append(all_sub + [int(j)])
            all_substitutes = lev_i

    # all_substitutes = all_substitutes[:24]
    all_substitutes = torch.tensor(all_substitutes) # [ N, L ]
    all_substitutes = all_substitutes.to(device)
    
    print(all_substitutes.shape) # (K ^ t, K)

    N, L = all_substitutes.size()
    word_predictions = mlm_model(all_substitutes)[0] # N L vocab-size
    ppl = c_loss(word_predictions.view(N*L, -1), all_substitutes.view(-1)) # [ N*L ] 
    ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1)) # N  
    
    _, word_list = torch.sort(ppl)
    word_list = [all_substitutes[i] for i in word_list]
    final_words = []
    for word in word_list[:24]:
        tokens = [tokenizer._convert_id_to_token(int(i)) for i in word]
        text = tokenizer.convert_tokens_to_string(tokens)
        final_words.append(text)
    return final_words

In [93]:
phrase_masked_list[0].replace(' [MASK]' * (p_e - p_s) + ' ', )

' [MASK] [MASK] [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.'

In [41]:
for (i,token) in enumerate(top_5_tokens):
    print(phrase_masked_list[i])
    for t in tokenizer.convert_ids_to_tokens(token):
        print(phrase_masked_list[i].replace(tokenizer.mask_token, t))
    print()


 [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 go ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 ah ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 yo ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 yahoo ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.
 freeze ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.

What a pity!  [MASK]  is bad compared to its predecessor, which is possibly due to its chaotic production process.
What a pity!  it  is bad compared to its predecessor, which is possibly due to its chaotic production process.
What a pity!  this  is bad compared to its predecessor, which is possibly due to its chaotic production process.
W