In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
# from transformers import AutoTokenizer, AutoModel
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import pickle
from utils import EncodingDataset, get_stopwords
from pathlib import Path
import itertools
from termcolor import colored
import time
import copy

In [7]:
torch.manual_seed(2020)

<torch._C.Generator at 0x7fa167793170>

In [8]:
# other option: imdb_data = load_dataset("imdb")
def read_imdb_split(split_dir):
    split_dir = Path(split_dir)
    texts = []
    labels = []
    for label_dir in ["pos", "neg"]:
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())
            labels.append(0 if label_dir is "neg" else 1)

    return texts, labels

In [9]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model_untrained = copy.deepcopy(model)

model.load_state_dict(torch.load("distilledBERT.pt"))
model.eval()
model_untrained.to(device)
model.to(device)

TRAIN = False
TEST = True
if TRAIN:
    train_texts, train_labels = read_imdb_split('aclImdb/train')
    train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)
    train_encodings = tokenizer(train_texts, truncation=True, padding=True)
    train_dataset = EncodingDataset(train_encodings, train_labels)
#     val_encodings = tokenizer(val_texts, truncation=True, padding=True)
#     val_dataset = EncodingDataset(val_encodings, val_labels)
    model.train()
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    optim = torch.optim.AdamW(model.parameters(), lr=5e-5)

    for epoch in range(3):
        for batch in train_loader:
            optim.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs[0]
            loss.backward()
            optim.step()

    model.eval()

if TEST:
    test_texts, test_labels = read_imdb_split('aclImdb/test')
    test_encodings = tokenizer(test_texts, truncation=True, padding=True)
    test_dataset = EncodingDataset(test_encodings, test_labels)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
    outputs_list = []
    
    with torch.no_grad():
        for i, batch in enumerate(test_loader):

            input_ids = batch['input_ids'].to(device)

            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            outputs_list.append(torch.argmax(outputs[1],dim=1).cpu().detach().numpy())

    pred = np.concatenate(outputs_list)
    print(classification_report(pred, test_labels))

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

              precision    recall  f1-score   support

           0       0.92      0.92      0.92     12443
           1       0.92      0.92      0.92     12557

    accuracy                           0.92     25000
   macro avg       0.92      0.92      0.92     25000
weighted avg       0.92      0.92      0.92     25000



## Black box using Importance score

In [25]:
# select an example for play
example_idx = 2020
example_text = test_texts[example_idx]
with torch.no_grad():
    example_tokenized = tokenizer(example_text, truncation=True, padding=True, return_tensors='pt').to(device)
    leave_1_token = get_leave_1_token(example_tokenized)
    example_origin_cls_emb = model.base_model(**example_tokenized)[0][:,0,:]


In [24]:
def get_leave_1_token(tokenized):
    '''
    Replace every token (i.e. tokenized word) in the tokenized sentence by '[UNK]' (unkwown token).
    '''
    output = {}
    token_num = tokenized['input_ids'].shape[1]
    output['input_ids'] = tokenized['input_ids'].repeat(token_num-2,1)
    output['attention_mask'] = tokenized['attention_mask'].repeat(token_num-2,1)
    for i, input_ids1 in enumerate(output['input_ids']):
        output['input_ids'][i, i+1] = 100 # make it to [UNK]
    return output

In [None]:
def compute_importance_score_from_pred_probs(origin_token, leave_1_token, model):
    '''
    TextFooler importance score.
    Dependant on the fine-tuned classification model.
    Not worthy for embedding adv.
    Just a demo.
    '''
    # no enough GPU memory
    with torch.no_grad(): # only need output probability
        orig_probs = model(**origin_token)[0]
        orig_label = torch.argmax(orig_probs)
        orig_prob = torch.max(orig_probs)
        
        leave_1_probs = model(**leave_1_token)[0]
        leave_1_probs_argmax = torch.argmax(leave_1_probs, dim=-1)
        
        import_scores = (orig_prob - leave_1_probs[:, orig_label]).cpu().numpy()
        
        import_scores += ((leave_1_probs_argmax != orig_label).float() * 
                         (leave_1_probs.max(dim=-1)[0] - torch.index_select(orig_probs, 1, leave_1_probs_argmax).reshape(-1))).cpu().numpy()
        return import_scores

In [None]:
def compute_importance_score_from_cls_emb(origin_token, leave_1_token, model, lamb=1):
    '''
    importance score for classification embedding: 
    MSE(leave_1_cls_emb, origin_cls_emb) - lamb * cosine_similarity(leave_1_cls_emb, origin_cls_emb)
    ========= Remark ==========
    seems like lamb has no influence on the import_scores, may because the similarity are close...
    '''
    # no enough GPU memory
    with torch.no_grad(): # only need output probability
        example_origin_cls_emb = model.base_model(**origin_token)[0][0][0]
        leave_1_cls_emb = model.base_model(**leave_1_token)[0][:,0,:]
        cos_sim = torch.matmul(leave_1_cls_emb, example_origin_cls_emb)/torch.norm(leave_1_cls_emb, dim=1)/torch.norm(example_origin_cls_emb)#.reshape(-1, 1))
        MSE = torch.norm(example_origin_cls_emb - leave_1_cls_emb, dim=1)
        import_scores = MSE - lamb*cos_sim
        
        return import_scores

In [357]:
def get_index_perturb_index(import_scores, threshold=None, select_n=None, get_important=True):
    '''
    Get some most important words.
    we can set a treshold of selected words' score (using threshold), or a number of selected words(using select_n)
    '''
    tokenid_perturb = []
    if threshold is not None:
        for idx, score in sorted(enumerate(import_scores), key=lambda x: x[1], reverse=get_important):
            try:
                if score > threshold: # and text_ls[idx] not in stop_words_set:
                    tokenid_perturb.append(idx)
            except:
                print(idx)
        tokenid_perturb = np.array(tokenid_perturb)
    elif select_n is not None:
        
        tokenid_perturb = np.array(sorted(enumerate(import_scores), key=lambda x: x[1], reverse=get_important))[:select_n, 0]
    
    return tokenid_perturb

In [358]:
import_scores_pred_probs = compute_importance_score_from_pred_probs(example_tokenized, leave_1_token, model)
tokenid_perturb_ids_pred_probs = get_index_perturb_index(import_scores_pred_probs, select_n=10)

In [359]:
for i,w  in enumerate(example_tokenized['input_ids'][0]):
    if i in tokenid_perturb_ids_pred_probs:
        print(colored(tokenizer.decode(w.reshape(1).cpu().numpy()) + "[{:.2f}]".format(import_scores_pred_probs[i].item()) , 'red'), end=" ")
    else:
        print(tokenizer.decode(w.reshape(1).cpu().numpy()), end=" ")

[CLS] after the atomic bomb hits hiroshima , charred bodies lie all around , def ##or ##med victims attempt to communicate with relatives who can ' t even recognize them , and one person after another dies of radiation sickness . this black and white film , however sad and scary , is not without humor . the story revolves around a young woman ya ##su ##ko , who was hit by black rain after the explosion . she is trying to get married , but everyone keeps dying , and people are worried the same will happen to her . after finding a suitable mate ( who is losing his mind after being in the war for too long ) , she ends up showing signs of radiation sickness . this [31mfilm[1.39][0m [31mis[1.48][0m [31ma[4.19][0m [31mgreat[3.38][0m [31mportrayal[0.83][0m of the atomic attacks on japan , it will fright ##en you , and will perhaps make you [31mcry[0.74][0m . the [31macting[1.11][0m [31mis[1.21][0m [31mgood[0.77][0m [31m,[0.77][0m not overly dramatic like many other w ##w #

In [360]:
import_scores_cls_emb = compute_importance_score_from_cls_emb(example_tokenized, leave_1_token, model, lamb=0)
tokenid_perturb_ids_cls_emb = get_index_perturb_index(import_scores_cls_emb, select_n=10, get_important=True)

In [361]:
for i,w  in enumerate(example_tokenized['input_ids'][0]):
    if i in tokenid_perturb_ids_cls_emb:
        print(colored(tokenizer.decode(w.reshape(1).cpu().numpy()) + "[{:.2f}]".format(import_scores_cls_emb[i].item()), 'red'), end=" ")
    else:
        print(tokenizer.decode(w.reshape(1).cpu().numpy()), end=" ")

[CLS] after the atomic bomb hits hiroshima , charred bodies lie all around , def ##or ##med victims attempt to communicate with relatives who can ' t even recognize them , and one person after another dies of radiation sickness . this black and white film , however sad and scary [31m,[2.46][0m [31mis[2.79][0m [31mnot[4.13][0m [31mwithout[12.79][0m humor . the story revolves around a young woman ya ##su ##ko , who was hit by black rain after the explosion . she is trying to get married , but everyone keeps dying , and people are worried the same will happen to her . after finding a suitable mate ( who is losing his mind after being in the war for too long ) , she ends up showing signs of radiation sickness . this [31mfilm[3.66][0m [31mis[3.63][0m [31ma[9.59][0m [31mgreat[7.87][0m portrayal of the atomic attacks on japan , it will fright ##en you , and will perhaps make you cry . the [31macting[2.76][0m [31mis[3.16][0m good , not overly dramatic like many other w ##w 

In [45]:
num_embedding = model.base_model.get_input_embeddings().num_embeddings

In [73]:
z = torch.rand([num_embedding,1]).requires_grad_(True).to(device)

In [74]:
input_embeddings = model.base_model.get_input_embeddings()(torch.tensor([list(np.arange(num_embedding))]).long().to(device))

In [87]:
torch.mul(torch.softmax(z, dim=0), input_embeddings.squeeze()).shape

torch.Size([30522, 768])

In [92]:
model.base_model(inputs_embeds=input_embeddings[:2])

RuntimeError: CUDA out of memory. Tried to allocate 41.65 GiB (GPU 0; 11.91 GiB total capacity; 1.54 GiB already allocated; 9.69 GiB free; 1.61 GiB reserved in total by PyTorch)

In [20]:
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package punkt to /home/tian/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/tian/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


True

In [26]:
text = nltk.word_tokenize(example_text)

In [27]:
nltk.pos_tag(text)

[('After', 'IN'),
 ('the', 'DT'),
 ('atomic', 'JJ'),
 ('bomb', 'NN'),
 ('hits', 'NNS'),
 ('Hiroshima', 'NNP'),
 (',', ','),
 ('charred', 'VBD'),
 ('bodies', 'NNS'),
 ('lie', 'VBP'),
 ('all', 'DT'),
 ('around', 'RB'),
 (',', ','),
 ('deformed', 'VBD'),
 ('victims', 'NNS'),
 ('attempt', 'VB'),
 ('to', 'TO'),
 ('communicate', 'VB'),
 ('with', 'IN'),
 ('relatives', 'NNS'),
 ('who', 'WP'),
 ('ca', 'MD'),
 ("n't", 'RB'),
 ('even', 'RB'),
 ('recognize', 'VB'),
 ('them', 'PRP'),
 (',', ','),
 ('and', 'CC'),
 ('one', 'CD'),
 ('person', 'NN'),
 ('after', 'IN'),
 ('another', 'DT'),
 ('dies', 'NNS'),
 ('of', 'IN'),
 ('radiation', 'NN'),
 ('sickness', 'NN'),
 ('.', '.'),
 ('This', 'DT'),
 ('black', 'JJ'),
 ('and', 'CC'),
 ('white', 'JJ'),
 ('film', 'NN'),
 (',', ','),
 ('however', 'RB'),
 ('sad', 'JJ'),
 ('and', 'CC'),
 ('scary', 'JJ'),
 (',', ','),
 ('is', 'VBZ'),
 ('not', 'RB'),
 ('without', 'IN'),
 ('humor', 'NN'),
 ('.', '.'),
 ('The', 'DT'),
 ('story', 'NN'),
 ('revolves', 'VBZ'),
 ('around', 

## Case study - Replace words

In [362]:
# word dictionary for BERT
word_dict = []
with open("bert-base-uncased-vocab.txt", "r") as f:
    for l in f.readlines():
        word_dict.append(l.strip())

In [None]:
# Brute force replace key words
word_replace_MSE_dict = {}
for i, w in enumerate(word_dict[999:], 1):
    if i % int(len(word_dict[999:])/10) == 0:
        print("Finished {:.2f}%.".format(i / int(len(word_dict[999:])/10) * 10))
    example_text_neg = example_text.replace("good", w)#.replace("good", "good")
    with torch.no_grad():
        example_tokenized_neg_untrained = tokenizer(example_text_neg, truncation=True, padding=True, return_tensors='pt').to(device)
        example_origin_neg_cls_emb_untrained = model_untrained.base_model(**example_tokenized_neg_untrained)[0][0][0]
        example_tokenized_untrained = tokenizer(example_text, truncation=True, padding=True, return_tensors='pt').to(device)
        example_origin_cls_emb_untrained = model_untrained.base_model(**example_tokenized_untrained)[0][:,0,:]
    word_replace_MSE_dict[w] = torch.norm(example_origin_neg_cls_emb_untrained - example_origin_cls_emb_untrained).item()


Finished 10.00%.
Finished 20.00%.
Finished 30.00%.
Finished 40.00%.
Finished 50.00%.
Finished 60.00%.
Finished 70.00%.


In [332]:
sorted(list(word_replace_MSE_dict.items()), key=lambda x:x[1])

[('good', 0.0),
 ('nice', 0.13435426354408264),
 ('great', 0.13668403029441833),
 ('well', 0.1559475213289261),
 ('fine', 0.16519874334335327),
 ('interesting', 0.17041532695293427),
 ('bad', 0.17150020599365234),
 ('best', 0.17429015040397644),
 ('wise', 0.17447488009929657),
 ('fast', 0.17616191506385803),
 ('excellent', 0.17644761502742767),
 ('kind', 0.17768529057502747),
 ('smart', 0.17798015475273132),
 ('strong', 0.17926186323165894),
 ('decent', 0.18114405870437622),
 ('serious', 0.18741022050380707),
 ('convincing', 0.18782706558704376),
 ('fair', 0.18830330669879913),
 ('clever', 0.18837271630764008),
 ('easy', 0.1899218112230301),
 ('hard', 0.18994919955730438),
 ('appealing', 0.18996204435825348),
 ('pretty', 0.19068105518817902),
 ('quality', 0.19214875996112823),
 ('fun', 0.1923743188381195),
 ('perfect', 0.1945510059595108),
 ('high', 0.1957029551267624),
 ('cool', 0.19599246978759766),
 ('quick', 0.19662681221961975),
 ('nicely', 0.1970011442899704),
 ('honest', 0.19803

In [333]:
example_text_neg = example_text.replace("good", "bad")#.replace("good", "good")

In [334]:
with torch.no_grad():
    example_tokenized_neg = tokenizer(example_text_neg, truncation=True, padding=True, return_tensors='pt').to(device)
    example_origin_neg_cls_emb = model.base_model(**example_tokenized_neg)[0][0][0]

In [335]:
torch.norm(example_origin_neg_cls_emb - example_origin_cls_emb)

tensor(5.4837, device='cuda:0')

In [336]:
model.classifier(example_origin_neg_cls_emb)

tensor([-0.4073, -0.2153], device='cuda:0', grad_fn=<AddBackward0>)

In [273]:
model.classifier(example_origin_cls_emb)

tensor([[-0.4173, -0.2270]], device='cuda:0', grad_fn=<AddmmBackward>)

## Antonym - NLTK

In [96]:
# import nltk
# nltk.download('wordnet')
# from nltk.corpus import wordnet as wn

# stop_words = get_stopwords()
# word_id_dict = pickle.load(open("word_id_dict.pkl", "rb"))
# id_word_dict = pickle.load(open("id_word_dict.pkl", "rb"))
# word_sim_embeddings = np.load("word_sim_embeddings.npy")

# greats = wn.synsets("up")
# [great_lemma.antonyms() for great_lemmas in [great.lemmas() for great in greats] for great_lemma in great_lemmas if great_lemma.antonyms()!=[]]

In [None]:
# # NOT WORKING: SIMILARITY only works for synonym
# tokenid_perturb = np.array([135])
# word2replace = tokenizer.decode(example_tokenized['input_ids'][0, tokenid_perturb].cpu().numpy())
# K = 10 # number of potential antonym

# if word2replace not in stop_words:
#     word2replace_id = word_id_dict[word2replace]
#     word_emb = word_sim_embeddings[word2replace_id]
#     K_antonym_id = np.argsort(np.abs(np.dot(word_sim_embeddings, word_emb.reshape(-1, 1)).reshape(-1)))[:K]
#     print(K_antonym_id)
#     for ant_id in K_antonym_id:
#         print(id_word_dict[ant_id])
#     #print(np.sort(np.dot(word_sim_embeddings, word_emb.reshape(-1, 1)).reshape(-1))[:K])