In [8]:
from pathlib import Path
from pprint import pprint

import torch  
import datasets
from datasets import concatenate_datasets
from tqdm import tqdm
from transformers import (
  AutoTokenizer,
  BertForSequenceClassification,
  AutoModelForMaskedLM,
  pipeline
)

from common.data_utils import get_dataset
from model.tokenizer import PhraseTokenizer
from model.attacker import Attacker

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [3]:
target_model = BertForSequenceClassification.from_pretrained('./data/imdb/saved_model/imdb_bert_base_uncased_finetuned_normal').to(device)

In [4]:
from transformers import BertForMaskedLM, BertTokenizer

In [52]:
model_name = "bert-large-uncased-whole-word-masking"
tokenizer = BertTokenizer.from_pretrained(model_name)
mlm_model = BertForMaskedLM.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [24]:
tgt_seq = "What a pity! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process."
entry = {'text': tgt_seq}

In [27]:
phrase_tok = PhraseTokenizer()
phrase_token_output = phrase_tok.tokenize(entry)

['tagger', 'parser', 'ner', 'merge_noun_chunks', 'merge_entities']


In [28]:
phrase_token_output

{'text': 'what a pity! frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'words': ['what',
  'a',
  'pity',
  '!',
  'frozen',
  '2',
  'is',
  'bad',
  'compared',
  'to',
  'its',
  'predecessor',
  ',',
  'which',
  'is',
  'possibly',
  'due',
  'to',
  'its',
  'chaotic',
  'production',
  'process',
  '.'],
 'word_offsets': [(0, 4),
  (5, 6),
  (7, 11),
  (11, 12),
  (13, 19),
  (20, 21),
  (22, 24),
  (25, 28),
  (29, 37),
  (38, 40),
  (41, 44),
  (45, 56),
  (56, 57),
  (58, 63),
  (64, 66),
  (67, 75),
  (76, 79),
  (80, 82),
  (83, 86),
  (87, 94),
  (95, 105),
  (106, 113),
  (113, 114)],
 'phrases': ['what a pity',
  '!',
  'frozen 2',
  'is',
  'bad',
  'compared',
  'to',
  'its predecessor',
  ',',
  'which',
  'is',
  'possibly',
  'due',
  'to',
  'its chaotic production process',
  '.'],
 'phrase_offsets': [(0, 11),
  (11, 12),
  (13, 21),
  (22, 24),
  (25, 28),
  (29, 37),
  (38, 40),
  (41, 56),
  (56, 57),
  

In [36]:
p_i = 0
p_s = 0
p_e = phrase_token_output['phrase_offsets'][p_i][1]
p_len = 0
phrase2word = []
new_p = True
word_count = 0
for w_s, w_e in phrase_token_output['word_offsets']:
    
    if new_p:
        p_s = word_count
        new_p = False
    
    if w_e == p_e:
        phrase2word.append([p_s, word_count+1])
        new_p = True
        p_i = min(p_i + 1, len(phrase_token_output['phrase_offsets']) - 1)
        p_e = phrase_token_output['phrase_offsets'][p_i][1]
    
    word_count += 1
        

In [37]:
phrase2word

[[0, 3],
 [3, 4],
 [4, 6],
 [6, 7],
 [7, 8],
 [8, 9],
 [9, 10],
 [10, 12],
 [12, 13],
 [13, 14],
 [14, 15],
 [15, 16],
 [16, 17],
 [17, 18],
 [18, 22],
 [22, 23]]

In [48]:
phrase_masked_list = []
word2char = phrase_token_output['word_offsets']

for p_s, p_e in phrase2word:
    if p_e - p_s >= 2:
        c_s = word2char[ p_s ][0]
        c_e = word2char[ p_e - 1][1]
        
        phrase_masked_list.append(tgt_seq[0:c_s] + ' [MASK] ' + tgt_seq[c_e:])

In [49]:
phrase_masked_list

[' [MASK] ! Frozen 2 is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity!  [MASK]  is bad compared to its predecessor, which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to  [MASK] , which is possibly due to its chaotic production process.',
 'What a pity! Frozen 2 is bad compared to its predecessor, which is possibly due to  [MASK] .']

In [50]:
inputs = tokenizer.encode(phrase_masked_list, return_tensors="pt").to(device)
mask_token_index = torch.where(inputs == tokenizer.mask_token_id)[1]

In [54]:
encodings = tokenizer(phrase_masked_list, truncation=True, padding=True, return_token_type_ids=False, return_tensors='pt')

In [55]:
encodings

{'input_ids': tensor([[  101,   103,   999,  7708,  1016,  2003,  2919,  4102,  2000,  2049,
          8646,  1010,  2029,  2003,  4298,  2349,  2000,  2049, 19633,  2537,
          2832,  1012,   102,     0],
        [  101,  2054,  1037, 12063,   999,   103,  2003,  2919,  4102,  2000,
          2049,  8646,  1010,  2029,  2003,  4298,  2349,  2000,  2049, 19633,
          2537,  2832,  1012,   102],
        [  101,  2054,  1037, 12063,   999,  7708,  1016,  2003,  2919,  4102,
          2000,   103,  1010,  2029,  2003,  4298,  2349,  2000,  2049, 19633,
          2537,  2832,  1012,   102],
        [  101,  2054,  1037, 12063,   999,  7708,  1016,  2003,  2919,  4102,
          2000,  2049,  8646,  1010,  2029,  2003,  4298,  2349,  2000,   103,
          1012,   102,     0,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 

In [57]:
mask_token_index = torch.where(encodings['input_ids'] == tokenizer.mask_token_id)[1]

In [58]:
mask_token_index

tensor([ 1,  5, 11, 19])

In [63]:
tokenizer.convert_ids_to_tokens(encodings['input_ids'][3])

['[CLS]',
 'what',
 'a',
 'pity',
 '!',
 'frozen',
 '2',
 'is',
 'bad',
 'compared',
 'to',
 'its',
 'predecessor',
 ',',
 'which',
 'is',
 'possibly',
 'due',
 'to',
 '[MASK]',
 '.',
 '[SEP]',
 '[PAD]',
 '[PAD]']

In [12]:


token_logits = mlm_model(inputs).logits
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(sequence.replace(tokenizer.mask_token, tokenizer.decode([token])))


Distilled models are smaller than the models they mimic. Using them instead of the large versions would help smaller out carbon footprint.
Distilled models are cheaper than the models they mimic. Using them instead of the large versions would help cheaper out carbon footprint.
Distilled models are simpler than the models they mimic. Using them instead of the large versions would help simpler out carbon footprint.
Distilled models are larger than the models they mimic. Using them instead of the large versions would help larger out carbon footprint.
Distilled models are lighter than the models they mimic. Using them instead of the large versions would help lighter out carbon footprint.


In [10]:
f"{tokenizer.mask_token}"

'[MASK]'

In [11]:
phrase_tok = PhraseTokenizer()

['tagger', 'parser', 'ner', 'merge_noun_chunks', 'merge_entities']


In [14]:
entry = {}
entry['text'] = sequence

In [16]:
tokenizer.tokenize(sequence)

['di',
 '##sti',
 '##lled',
 'models',
 'are',
 '[MASK]',
 'than',
 'the',
 'models',
 'they',
 'mimic',
 '.',
 'using',
 'them',
 'instead',
 'of',
 'the',
 'large',
 'versions',
 'would',
 'help',
 '[MASK]',
 'out',
 'carbon',
 'footprint',
 '.']

In [15]:
phrase_tok.tokenize(entry)

{'text': 'distilled models are [mask] than the models they mimic. using them instead of the large versions would help [mask] out carbon footprint.',
 'words': ['distilled',
  'models',
  'are',
  '[',
  'mask',
  ']',
  'than',
  'the',
  'models',
  'they',
  'mimic',
  '.',
  'using',
  'them',
  'instead',
  'of',
  'the',
  'large',
  'versions',
  'would',
  'help',
  '[',
  'mask',
  ']',
  'out',
  'carbon',
  'footprint',
  '.'],
 'word_offsets': [(0, 9),
  (10, 16),
  (17, 20),
  (21, 22),
  (22, 26),
  (26, 27),
  (28, 32),
  (33, 36),
  (37, 43),
  (44, 48),
  (49, 54),
  (54, 55),
  (56, 61),
  (62, 66),
  (67, 74),
  (75, 77),
  (78, 81),
  (82, 87),
  (88, 96),
  (97, 102),
  (103, 107),
  (108, 109),
  (109, 113),
  (113, 114),
  (115, 118),
  (119, 125),
  (126, 135),
  (135, 136)],
 'phrases': ['distilled models',
  'are',
  '[mask',
  ']',
  'than',
  'the models',
  'they',
  'mimic',
  '.',
  'using',
  'them',
  'instead',
  'of',
  'the large versions',
  'would',