In [2]:
import json
import os
import pickle
import torch
from torch.utils.data import DataLoader
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import BertTokenizer, BertForSequenceClassification, AdamW

In [3]:
DATASET_NAME = "IMDb"
DATASET_SMALLNAME = "aclImdb"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
REFORMED_DATASET_PATH = f"dataset/{DATASET_NAME}/original_augmented_1x_{DATASET_SMALLNAME}"
OUTPUT_PATH = f"checkpoints/{DATASET_NAME}/original_augmented_1x_output_scheduling_warmup_bs8_2"
REPS_PATH = "reps"
PICKLE_PATH = f"dataset/{DATASET_NAME}/cf_augmented_examples"
TRAIN_SPLIT = "train"
TEST_SPLIT = "test"

In [4]:
class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [5]:
# Load dataset
# with open(os.path.join(REFORMED_DATASET_PATH, "train.json")) as f:
#     train = json.load(f) 
# with open(os.path.join(REFORMED_DATASET_PATH, "valid.json")) as f:
#     val = json.load(f)

In [6]:
# train_texts = [d['text'] for d in train]
# train_labels = [d['label'] for d in train]
# val_texts = [d['text'] for d in val]
# val_labels = [d['label'] for d in val]

In [7]:
# Encode dataset
# train_encodings = tokenizer(train_texts, truncation=True, padding=True)
# val_encodings = tokenizer(val_texts, truncation=True, padding=True)

In [8]:
# make dataset class
# train_dataset = IMDbDataset(train_encodings, train_labels)
# val_dataset = IMDbDataset(val_encodings, val_labels)

In [9]:
# train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
# val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [10]:
# Define tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [11]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = BertForSequenceClassification.from_pretrained(os.path.join(OUTPUT_PATH, 'epoch_2'))
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

## Gradient-based Masking

In [12]:
with open(os.path.join(REFORMED_DATASET_PATH, "train.json")) as f:
    data = json.load(f)

In [13]:
train_texts = [d['anchor_text'] for d in data]
train_labels = [d['label'] for d in data]

# Encode dataset
train_encodings = tokenizer(train_texts, truncation=True, padding=True)

# make dataset class
train_dataset = IMDbDataset(train_encodings, train_labels)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)

In [14]:
# Compute gradient at BERT's position_embeddings (discard [cls] and [sep]/[pad])
# Only works for batch_size = 1    
def get_gradient_norms(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    # For CrossEntropy Loss
    _, labels = torch.max(labels, dim=1)

    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs[0]
    # print(loss)
    loss.backward(retain_graph=True)
    torch.cuda.empty_cache()

    importances = torch.tensor([]).to(device)
    for pos_index, token_index in zip(range(1, len(input_ids[0])), input_ids[0][1:]):
        if token_index == tokenizer.sep_token_id:
            break

        importance = torch.norm(model._modules['bert']._modules['embeddings']._modules['position_embeddings'].weight.grad[pos_index], 2).float().detach()
        importances = torch.cat((importances, importance.unsqueeze(0)), dim=-1)

    # importances_list.append(importances)
    model._modules['bert']._modules['embeddings']._modules['position_embeddings'].weight.grad = None

    # return importances_list
    return importances

In [15]:
# Compute gradient at BERT's position_embeddings (discard [cls] and [sep]/[pad])
# Only works for batch_size = 1    
def get_token_gradient_norms(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    # For CrossEntropy Loss
    _, labels = torch.max(labels, dim=1)

    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs[0]
    # print(loss)
    loss.backward(retain_graph=True)
    torch.cuda.empty_cache()

    importances = torch.tensor([]).to(device)
    for pos_index, token_index in zip(range(1, len(input_ids[0])), input_ids[0][1:]):
        if token_index == tokenizer.sep_token_id:
            break

        importance = torch.norm(model._modules['bert']._modules['embeddings']._modules['word_embeddings'].weight.grad[token_index], 2).float().detach()
        importances = torch.cat((importances, importance.unsqueeze(0)), dim=-1)

    # importances_list.append(importances)
    model._modules['bert']._modules['embeddings']._modules['word_embeddings'].weight.grad = None

    # return importances_list
    return importances

In [16]:
import random
import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
import pickle
def visualize(words, masks):
    fig, ax = plt.subplots(figsize=(len(words), 1))
    plt.rc('xtick', labelsize=16)
    heatmap = sn.heatmap([masks], xticklabels=words, yticklabels=False, square=True, \
                         linewidths=0.1, cmap='coolwarm', center=0.5, vmin=0, vmax=1)
    plt.xticks(rotation=45)
    plt.show()
    

def mask_causal_words(tokens, importances, topk=1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1][:topk]
    for topk_idx in topk_indices:
#         print(topk_idx)
#         print(tokens[topk_idx])
        causal_mask[topk_idx] = 1
    
    return causal_mask

def compute_importances(data_loader, importance_function):
    all_importances = []
    for batch in tqdm(data_loader):
        importances = importance_function(batch)
        all_importances.append(importances)
    return all_importances



def build_causal_mask_with_precomputed(data_loader, all_importances, sampling_ratio, augment_ratio):
    triplets = []
    error_cnt = 0
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = torch.tensor([x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]])
        assert tokens.size() == importances.size()
        
        orig_sample = tokenizer.decode(tokens)
        causal_mask = mask_causal_words(tokens.cpu().numpy(), importances.cpu().numpy(), topk=sampling_ratio)
        # visualize(tokens, causal_mask)
        # print(causal_mask)
        
        if 1 not in causal_mask:
            # print(orig_sample[1], cf_sample[1])
            continue
        
        for _ in range(augment_ratio):
            # 모든 causal 단어를 mask, 모든 non-causal 단어를 mask
            if sampling_ratio is None:
                causal_masked_tokens = [tokens[i] if causal_mask[i] == 0 else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if causal_mask[i] == 1 else tokenizer.mask_token_id for i in range(len(tokens))]

            # sampling_ratio 갯수 (int) 만큼의 단어를 mask
            elif type(sampling_ratio) == int:
                causal_indices = np.where(np.array(causal_mask) == 1)[0]
                noncausal_indices = np.where(np.array(causal_mask) == 0)[0]

                # print(causal_indices)

                causal_mask_indices = np.random.choice(causal_indices, sampling_ratio)                    
                try:
                    noncausal_mask_indices = np.random.choice(noncausal_indices, max(1, min(sampling_ratio, len(noncausal_indices))))
                    #noncausal_mask_indices = np.random.choice(noncausal_indices, 1)
                except:
                    noncausal_mask_indices = np.random.choice(causal_indices, sampling_ratio)
                    error_cnt += 1

                causal_masked_tokens = [tokens[i] if i not in causal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if i not in noncausal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
            
            # sampling_ratio 비율 (%) 만큼의 단어를 mask
            else:
                pass
                
            causal_masked_sample = tokenizer.decode(causal_masked_tokens)
            noncausal_masked_sample = tokenizer.decode(noncausal_masked_tokens)
            
            _, labels = torch.max(batch['labels'], dim=1)
            if labels[0] == 0: label = 'Negative'
            elif labels[0] == 1: label = 'Positive'
            triplets.append((label, orig_sample, causal_masked_sample, noncausal_masked_sample))
    print(f"Error Cnt: {error_cnt}")    
    return triplets

def mask_propensity_causal_words(tokens, batch, importances, topk=1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    err_flag = False
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    orig_outputs = model(input_ids, attention_mask=attention_mask)
    _, orig_prediction = torch.max(orig_outputs[0], dim=1)
    for i, topk_idx in enumerate(topk_indices):
        masked_input_ids = input_ids.clone()
        masked_input_ids[0][topk_idx + 1] = tokenizer.mask_token_id
        masked_outputs = model(masked_input_ids, attention_mask=attention_mask)
        _, masked_prediction = torch.max(masked_outputs[0], dim=1)
        if orig_prediction != masked_prediction:
            causal_mask[topk_idx] = 1
            break
    if 1 not in causal_mask:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
    return causal_mask, err_flag 

def build_propensity_causal_mask_with_precomputed(data_loader, all_importances, sampling_ratio, augment_ratio):
    triplets = []
    error_cnt = 0
    no_flip_cnt = 0
    no_flip_idx = []
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = torch.tensor([x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]])
        assert tokens.size() == importances.size()
        
        orig_sample = tokenizer.decode(tokens)
        causal_mask, err_flag = mask_propensity_causal_words(tokens.cpu().numpy(), batch, importances.cpu().numpy(), topk=sampling_ratio)
        no_flip_idx.append(err_flag)
        if err_flag:
            no_flip_cnt += 1
        # visualize(tokens, causal_mask)
        # print(causal_mask)
        
        if 1 not in causal_mask:
            # print(orig_sample[1], cf_sample[1])
            continue
        
        for _ in range(augment_ratio):
            # 모든 causal 단어를 mask, 모든 non-causal 단어를 mask
            if sampling_ratio is None:
                causal_masked_tokens = [tokens[i] if causal_mask[i] == 0 else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if causal_mask[i] == 1 else tokenizer.mask_token_id for i in range(len(tokens))]

            # sampling_ratio 갯수 (int) 만큼의 단어를 mask
            elif type(sampling_ratio) == int:
                causal_indices = np.where(np.array(causal_mask) == 1)[0]
                noncausal_indices = np.where(np.array(causal_mask) == 0)[0]

                # print(causal_indices)

                causal_mask_indices = np.random.choice(causal_indices, sampling_ratio)                    
                try:
                    noncausal_mask_indices = np.random.choice(noncausal_indices, max(1, min(sampling_ratio, len(noncausal_indices))))
                    #noncausal_mask_indices = np.random.choice(noncausal_indices, 1)
                except:
                    noncausal_mask_indices = np.random.choice(causal_indices, sampling_ratio)
                    error_cnt += 1

                causal_masked_tokens = [tokens[i] if i not in causal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if i not in noncausal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
            
            # sampling_ratio 비율 (%) 만큼의 단어를 mask
            else:
                pass
                
            causal_masked_sample = tokenizer.decode(causal_masked_tokens)
            noncausal_masked_sample = tokenizer.decode(noncausal_masked_tokens)
            
            _, labels = torch.max(batch['labels'], dim=1)
            if labels[0] == 0: label = 'Negative'
            elif labels[0] == 1: label = 'Positive'
            triplets.append((label, orig_sample, causal_masked_sample, noncausal_masked_sample, err_flag))
    print(f"Error Cnt: {error_cnt}")    
    print(f"No Flip Cnt: {no_flip_cnt}")    
    return triplets, no_flip_idx

def compute_average_importance(data_loader, all_importances):
    all_averaged_importances = []
    importance_dict = dict()
    importance_dict_counter = dict()
    
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = [x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]]
        
        for tok_imp, tok in zip(importances, tokens):
            if not tok in importance_dict.keys():
                importance_dict[tok.item()] = 0
                importance_dict_counter[tok.item()] = 0
            importance_dict[tok.item()] += tok_imp
            importance_dict_counter[tok.item()] += 1
            
    
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = [x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]]
        averaged_importances = torch.tensor([importance_dict[x.item()]/importance_dict_counter[x.item()] for x in tokens])
        all_averaged_importances.append(averaged_importances)
    
    return all_averaged_importances

In [17]:
if not os.path.exists(PICKLE_PATH):
    os.makedirs(PICKLE_PATH)

#all_importance = compute_importances(train_loader, get_gradient_norms)
#with open("dataset/SST-2/cf_augmented_examples/gradient_importance.pickle", 'wb') as f:
#    pickle.dump(all_importance, f)

if os.path.exists(os.path.join(PICKLE_PATH, "gradient_importance.pickle")):
    with open(os.path.join(PICKLE_PATH, "gradient_importance.pickle"), 'rb') as f:
        all_importance = pickle.load(f)
else:
    all_importance = compute_importances(train_loader, get_gradient_norms)
    with open(os.path.join(PICKLE_PATH, "gradient_importance.pickle"), 'wb') as f:
        pickle.dump(all_importance, f)

In [18]:
averaged_all_importance = compute_average_importance(train_loader, all_importance)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [19]:
with open(os.path.join(PICKLE_PATH, "gradient_averaged_importance.pickle"), 'wb') as f:
        pickle.dump(averaged_all_importance, f)

In [None]:
#for sampling_ratio in [1, 2, 3, 4, 5]:
for sampling_ratio in [1]:
    for augment_ratio in [1]:
        triplets_train, no_flip_idx_train = build_propensity_causal_mask_with_precomputed(train_loader, averaged_all_importance, sampling_ratio=sampling_ratio, augment_ratio=augment_ratio)
        with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_propensity_sampling{}_augmenting{}_train.pickle".format(sampling_ratio, augment_ratio)), "wb") as fp:
            pickle.dump(triplets_train, fp)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

In [23]:
with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_propensity_sampling{}_augmenting{}_flip_idx_train.pickle".format(sampling_ratio, augment_ratio)), "wb") as fp:
    pickle.dump(triplets_train, fp)

In [24]:
triplets_train[0]

('Negative',
 'hide new secretions from the parental units',
 'hide new secretions from the [MASK] units',
 'hide new secret [MASK] from the parental units',
 True)

# Qualitative Test

In [19]:
with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_propensity_sampling1_augmenting1_train.pickle".format(sampling_ratio, augment_ratio)), "rb") as fp:
            data_b = pickle.load(fp)
        
with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_sampling1_augmenting1_train.pickle".format(sampling_ratio, augment_ratio)), "rb") as fp:
            data_a = pickle.load(fp)

In [35]:
cnt = 0
diff_idxes = []
for i, (a, b) in enumerate(zip(data_a, data_b)):
    if a[2] != b[2]:
        diff_idxes.append(i)

In [53]:
tmp_i = 3
print("[ORIG]:    " + data_a[diff_idxes[tmp_i]][1])
print()
print("[GRAD]:    " + data_a[diff_idxes[tmp_i]][2])
print()
print("[PROP]:    " + data_b[diff_idxes[tmp_i]][2])

[ORIG]:    goes to absurd lengths

[GRAD]:    goes to absurd [MASK]

[PROP]:    goes to [MASK] lengths
