In [1]:
import json
import os
import pickle
import torch
from torch.utils.data import DataLoader
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import BertTokenizer, BertForSequenceClassification, AdamW

In [11]:
DATASET_NAME = "CFNLI"
DATASET_SMALLNAME = "cfnli"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
REFORMED_DATASET_PATH = f"dataset/{DATASET_NAME}/original_augmented_1x_{DATASET_SMALLNAME}"
OUTPUT_PATH = f"checkpoints/{DATASET_NAME}/original_augmented_1x_output_scheduling_warmup_lambda_01_0"
REPS_PATH = "reps"
PICKLE_PATH = f"dataset/{DATASET_NAME}/cf_augmented_examples"
TRAIN_SPLIT = "train"
TEST_SPLIT = "test"
NUM_LABELS = 3

In [3]:
class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [4]:
# Load dataset
# with open(os.path.join(REFORMED_DATASET_PATH, "train.json")) as f:
#     train = json.load(f) 
# with open(os.path.join(REFORMED_DATASET_PATH, "valid.json")) as f:
#     val = json.load(f)

In [5]:
# train_texts = [d['text'] for d in train]
# train_labels = [d['label'] for d in train]
# val_texts = [d['text'] for d in val]
# val_labels = [d['label'] for d in val]

In [6]:
# Encode dataset
# train_encodings = tokenizer(train_texts, truncation=True, padding=True)
# val_encodings = tokenizer(val_texts, truncation=True, padding=True)

In [7]:
# make dataset class
# train_dataset = IMDbDataset(train_encodings, train_labels)
# val_dataset = IMDbDataset(val_encodings, val_labels)

In [8]:
# train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
# val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [9]:
# Define tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [10]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = BertForSequenceClassification.from_pretrained(os.path.join(OUTPUT_PATH, 'best_epoch'), num_labels=NUM_LABELS)
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

## Gradient-based Masking

In [12]:
with open(os.path.join(REFORMED_DATASET_PATH, "train.json")) as f:
    data = json.load(f)

In [13]:
def split_sentences(text):
    sp = text.split(" [SEP] ")
    if len(sp) > 1:
        return sp
    else:
        return sp[0]

train_texts = [split_sentences(d['anchor_text']) for d in data]
train_labels = [d['label'] for d in data]

# Encode dataset
train_encodings = tokenizer(train_texts, truncation=True, padding=True)

# make dataset class
train_dataset = IMDbDataset(train_encodings, train_labels)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)

# Gradient Scoring Functions

In [14]:
# Compute gradient at BERT's position_embeddings (discard [cls] and [sep]/[pad])
# Only works for batch_size = 1    
def get_gradient_norms(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)
    labels = batch['labels'].to(device)
    # For CrossEntropy Loss
    _, labels = torch.max(labels, dim=1)

    outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
    loss = outputs[0]
    # print(loss)
    loss.backward(retain_graph=True)
    torch.cuda.empty_cache()

    importances = torch.tensor([]).to(device)
    for pos_index, token_index in zip(range(1, len(input_ids[0])), input_ids[0][1:]):
        if token_index == tokenizer.pad_token_id:
            break

        importance = torch.norm(model._modules['bert']._modules['embeddings']._modules['position_embeddings'].weight.grad[pos_index], 2).float().detach()
        importances = torch.cat((importances, importance.unsqueeze(0)), dim=-1)

    # importances_list.append(importances)
    model._modules['bert']._modules['embeddings']._modules['position_embeddings'].weight.grad = None

    # return importances_list
    return importances

# Compute gradient at BERT's position_embeddings (discard [cls] and [sep]/[pad])
# Only works for batch_size = 1    
def get_token_gradient_norms(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)
    labels = batch['labels'].to(device)
    # For CrossEntropy Loss
    _, labels = torch.max(labels, dim=1)

    outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
    loss = outputs[0]
    # print(loss)
    loss.backward(retain_graph=True)
    torch.cuda.empty_cache()

    importances = torch.tensor([]).to(device)
    for pos_index, token_index in zip(range(1, len(input_ids[0])), input_ids[0][1:]):
        if token_index == tokenizer.pad_token_id:
            break

        importance = torch.norm(model._modules['bert']._modules['embeddings']._modules['word_embeddings'].weight.grad[token_index], 2).float().detach()
        importances = torch.cat((importances, importance.unsqueeze(0)), dim=-1)

    # importances_list.append(importances)
    model._modules['bert']._modules['embeddings']._modules['word_embeddings'].weight.grad = None

    # return importances_list
    return importances

# Utility Functions

In [15]:
import random
import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
import pickle
def visualize(words, masks):
    fig, ax = plt.subplots(figsize=(len(words), 1))
    plt.rc('xtick', labelsize=16)
    heatmap = sn.heatmap([masks], xticklabels=words, yticklabels=False, square=True, \
                         linewidths=0.1, cmap='coolwarm', center=0.5, vmin=0, vmax=1)
    plt.xticks(rotation=45)
    plt.show()
    

def mask_causal_words(tokens, importances, topk=1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1][:topk]
    for topk_idx in topk_indices:
#         print(topk_idx)
#         print(tokens[topk_idx])
        causal_mask[topk_idx] = 1
    
    return causal_mask

"""
#def mask_causal_words(tokens, importances, topk=1):
def mask_causal_words(tokens, importances, topk=-1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    
    cnt = 0
    for topk_idx in topk_indices:
#         print(topk_idx)
#         print(tokens[topk_idx])
        if token[topk_idx] == tokenizer.sep_token_id:
            print("DEBUG: SEP is detected!!")
            continue
        causal_mask[topk_idx] = 1
        cnt += 1
        if cnt == topk:
            break
    
    return causal_mask
"""


def compute_importances(data_loader, importance_function):
    all_importances = []
    for batch in tqdm(data_loader):
        importances = importance_function(batch)
        all_importances.append(importances)
    return all_importances


def compute_average_importance(data_loader, all_importances):
    all_averaged_importances = []
    importance_dict = dict()
    importance_dict_counter = dict()
    
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = [x for x in batch['input_ids'][0][1:] if x not in [tokenizer.pad_token_id]]
        
        for tok_imp, tok in zip(importances, tokens):
            assert tok_imp > 0
            if not tok in importance_dict.keys():
                importance_dict[tok.item()] = 0
                importance_dict_counter[tok.item()] = 0
            importance_dict[tok.item()] += tok_imp
            importance_dict_counter[tok.item()] += 1
    
    ### [SEP] Token is filtered!!!! ###
    importance_dict[tokenizer.sep_token_id] = 0
    
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = [x for x in batch['input_ids'][0][1:] if x not in [tokenizer.pad_token_id]]
        averaged_importances = torch.tensor([importance_dict[x.item()]/importance_dict_counter[x.item()] for x in tokens])
        all_averaged_importances.append(averaged_importances)
    
    return all_averaged_importances


# Compute Gradient-based Causal Masks

In [16]:
def build_causal_mask_with_precomputed(data_loader, all_importances, sampling_ratio, augment_ratio):
    triplets = []
    error_cnt = 0
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = torch.tensor([x for x in batch['input_ids'][0][1:] if x not in [tokenizer.pad_token_id]])
        assert tokens.size() == importances.size()
        
        orig_sample = tokenizer.decode(tokens[:-1])
        causal_mask = mask_causal_words(tokens.cpu().numpy(), importances.cpu().numpy(), topk=sampling_ratio)
        # visualize(tokens, causal_mask)
        # print(causal_mask)
        
        if 1 not in causal_mask:
            # print(orig_sample[1], cf_sample[1])
            continue
        
        for _ in range(augment_ratio):
            # 모든 causal 단어를 mask, 모든 non-causal 단어를 mask
            if sampling_ratio is None:
                causal_masked_tokens = [tokens[i] if causal_mask[i] == 0 else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if causal_mask[i] == 1 else tokenizer.mask_token_id for i in range(len(tokens))]

            # sampling_ratio 갯수 (int) 만큼의 단어를 mask
            elif type(sampling_ratio) == int:
                causal_indices = np.where(np.array(causal_mask) == 1)[0]
                noncausal_indices = np.where(np.array(causal_mask) == 0)[0]

                # print(causal_indices)

                causal_mask_indices = np.random.choice(causal_indices, sampling_ratio)                    
                try:
                    noncausal_mask_indices = np.random.choice(noncausal_indices, max(1, min(sampling_ratio, len(noncausal_indices))))
                    #noncausal_mask_indices = np.random.choice(noncausal_indices, 1)
                except:
                    noncausal_mask_indices = np.random.choice(causal_indices, sampling_ratio)
                    error_cnt += 1
                
                causal_masked_tokens = [tokens[i] if i not in causal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if i not in noncausal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
            
            # sampling_ratio 비율 (%) 만큼의 단어를 mask
            else:
                pass
            
            ### EDIT FOR NLI: REMOVE SEP TOKEN & SET LABEL ###
            causal_masked_sample = tokenizer.decode(causal_masked_tokens[:-1])
            noncausal_masked_sample = tokenizer.decode(noncausal_masked_tokens[:-1])
            
            _, labels = torch.max(batch['labels'], dim=1)
            if labels[0] == 0: label = 'contradiction'
            elif labels[0] == 1: label = 'entailment'
            else: label = 'neutral'
            triplets.append((label, orig_sample, causal_masked_sample, noncausal_masked_sample, False, 0))
    print(f"Error Cnt: {error_cnt}")    
    return triplets, 0



# Compute Propensity-based Causal Mask

In [17]:
def _TVD(orig_logits, cf_logits):
    return 0.5 * torch.cdist(orig_logits.unsqueeze(0), cf_logits.unsqueeze(0), p=1).squeeze().item()

def mask_uniform_propensity_causal_words(tokens, batch, importances, topk=1):
    THRESHOLD = 0.1
    uniform_dist = torch.ones(1, 2) / 2.0
    uniform_dist = uniform_dist.to("cuda")
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    err_flag = False
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)
    orig_outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    _, orig_prediction = torch.max(orig_outputs[0], dim=1)
    best_tvd = 0.
    best_idx = -1
    for i, topk_idx in enumerate(topk_indices):
        # For excepting [SEP] token ...
        if importances[topk_idx] == 0:
            continue
        masked_input_ids = input_ids.clone()
        masked_input_ids[0][topk_idx + 1] = tokenizer.mask_token_id
        masked_outputs = model(masked_input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        _, masked_prediction = torch.max(masked_outputs[0], dim=1)
        orig_tvd = _TVD(torch.softmax(orig_outputs[0], dim=-1), uniform_dist)
        masked_tvd = _TVD(torch.softmax(masked_outputs[0], dim=-1), uniform_dist)
        
        # Use maximum value
        if orig_tvd > masked_tvd and abs(orig_tvd - masked_tvd) > best_tvd:
            causal_mask[best_idx] = 0
            causal_mask[topk_idx] = 1
            best_tvd = abs(orig_tvd - masked_tvd)
            best_idx = topk_idx
        else:
            continue
        """
        # Use Gradient Order
        if orig_tvd > masked_tvd:
            causal_mask[topk_idx] = 1
            break
        else:
            continue
        """    
        
    if 1 not in causal_mask:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
    return causal_mask, err_flag, best_tvd

def mask_softed_propensity_causal_words(tokens, batch, importances, topk=1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    err_flag = False
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)
    orig_outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    _, orig_prediction = torch.max(orig_outputs[0], dim=1)
    best_tvd = 0.
    best_idx = -1
    for i, topk_idx in enumerate(topk_indices):
        # For excepting [SEP] token ...
        if importances[topk_idx] == 0:
            continue
        masked_input_ids = input_ids.clone()
        masked_input_ids[0][topk_idx + 1] = tokenizer.mask_token_id
        masked_outputs = model(masked_input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        _, masked_prediction = torch.max(masked_outputs[0], dim=1)
        tvd_value = _TVD(torch.softmax(orig_outputs[0], dim=-1), torch.softmax(masked_outputs[0], dim=-1))
        
        # Use Maximum  Value
        if tvd_value > best_tvd:
            causal_mask[best_idx] = 0
            causal_mask[topk_idx] = 1
            best_tvd = tvd_value
            best_idx = topk_idx
        else:
            continue
            
        """
        # Use Gradeint Order
        if orig_tvd > MIN_FLIPPED:
            causal_mask[topk_idx] = 1
            break
        else:
            continue
        """
        
    if 1 not in causal_mask:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
    return causal_mask, err_flag, best_tvd

def mask_propensity_causal_words(tokens, batch, importances, topk=1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    err_flag = False
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)
    orig_outputs = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    _, orig_prediction = torch.max(orig_outputs[0], dim=1)
    for i, topk_idx in enumerate(topk_indices):
        # For excepting [SEP] token ...
        if importances[topk_idx] == 0:
            continue
        masked_input_ids = input_ids.clone()
        masked_input_ids[0][topk_idx + 1] = tokenizer.mask_token_id
        masked_outputs = model(masked_input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        _, masked_prediction = torch.max(masked_outputs[0], dim=1)
        if orig_prediction != masked_prediction:
            causal_mask[topk_idx] = 1
            break
    if 1 not in causal_mask:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
    return causal_mask, err_flag, 0

def build_propensity_causal_mask_with_precomputed(data_loader, all_importances, sampling_ratio, augment_ratio):
    triplets = []
    error_cnt = 0
    no_flip_cnt = 0
    no_flip_idx = []
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = torch.tensor([x for x in batch['input_ids'][0][1:] if x not in [tokenizer.pad_token_id]])
        assert tokens.size() == importances.size()
        if len(tokens) == 0:
            print(batch['input_ids'][0])
            triplets.append((label, orig_sample, orig_sample, orig_sample, True, 0))
            continue
        orig_sample = tokenizer.decode(tokens[:-1])
        #causal_mask, err_flag, maximum_score = mask_propensity_causal_words(tokens.cpu().numpy(), batch, importances.cpu().numpy(), topk=sampling_ratio)
        #causal_mask, err_flag, maximum_score = mask_softed_propensity_causal_words(tokens.cpu().numpy(), batch, importances.cpu().numpy(), topk=sampling_ratio)
        causal_mask, err_flag, maximum_score = mask_uniform_propensity_causal_words(tokens.cpu().numpy(), batch, importances.cpu().numpy(), topk=sampling_ratio)
        no_flip_idx.append(err_flag)
        if err_flag:
            no_flip_cnt += 1
        # visualize(tokens, causal_mask)
        # print(causal_mask)
        
        if 1 not in causal_mask:
            print(tokens)
            triplets.append((label, orig_sample, orig_sample, orig_sample, err_flag, maximum_score))
            continue
        
        for _ in range(augment_ratio):
            # 모든 causal 단어를 mask, 모든 non-causal 단어를 mask
            if sampling_ratio is None:
                causal_masked_tokens = [tokens[i] if causal_mask[i] == 0 else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if causal_mask[i] == 1 else tokenizer.mask_token_id for i in range(len(tokens))]

            # sampling_ratio 갯수 (int) 만큼의 단어를 mask
            elif type(sampling_ratio) == int:
                causal_indices = np.where(np.array(causal_mask) == 1)[0]
                noncausal_indices = np.where(np.array(causal_mask) == 0)[0]

                # print(causal_indices)

                causal_mask_indices = np.random.choice(causal_indices, sampling_ratio)                    
                try:
                    noncausal_mask_indices = np.random.choice(noncausal_indices, max(1, min(sampling_ratio, len(noncausal_indices))))
                    #noncausal_mask_indices = np.random.choice(noncausal_indices, 1)
                except:
                    noncausal_mask_indices = np.random.choice(causal_indices, sampling_ratio)
                    error_cnt += 1

                causal_masked_tokens = [tokens[i] if i not in causal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if i not in noncausal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
            
            # sampling_ratio 비율 (%) 만큼의 단어를 mask
            else:
                pass
                
            ### EDIT FOR NLI: REMOVE SEP TOKEN & SET LABEL ###
            causal_masked_sample = tokenizer.decode(causal_masked_tokens[:-1])
            noncausal_masked_sample = tokenizer.decode(noncausal_masked_tokens[:-1])
            
            _, labels = torch.max(batch['labels'], dim=1)
            if labels[0] == 0: label = 'contradiction'
            elif labels[0] == 1: label = 'entailment'
            else: label = 'neutral'
            triplets.append((label, orig_sample, causal_masked_sample, noncausal_masked_sample, err_flag, maximum_score))
    print(f"Error Cnt: {error_cnt}")    
    print(f"No Flip Cnt: {no_flip_cnt}")    
    return triplets, no_flip_idx


# Compute LM-based Causal Mask

In [25]:
from transformers import BertForMaskedLM

mlm_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
mlm_model = mlm_model.to(device)
mlm_model.eval()
TOPK_NUM = 4

def mask_efficient_LM_dropout_causal_words(tokens, batch, importances, topk=1):
    dropout = torch.nn.Dropout(0.5)
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    err_flag = False
    find_flag = False
    
    #Skip neutral sample
    if torch.max(batch['labels'].squeeze()).item() == 2:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
        return causal_mask, err_flag, 0
    
    input_ids = batch['input_ids'].squeeze().repeat((TOPK_NUM, )).reshape(TOPK_NUM, -1).to(device)
    attention_mask = batch['attention_mask'].expand(TOPK_NUM, -1).to(device)
    token_type_ids = batch['token_type_ids'].expand(TOPK_NUM, -1).to(device)
    
    masked_input_ids = batch['input_ids'].squeeze().repeat((len(tokens), )).reshape(len(tokens), -1).to(device)
    masked_attention_mask = batch['attention_mask'].squeeze().repeat((len(tokens), )).reshape(len(tokens), -1).to(device)
    masked_token_type_ids = batch['token_type_ids'].expand(len(tokens), -1).to(device)
    fake_labels = torch.ones((len(tokens),))
    masked_train = IMDbDataset({
        'input_ids': masked_input_ids, 
        'attention_mask': masked_attention_mask, 
        'token_type_ids': masked_token_type_ids, 
        'topk_indices': topk_indices}, fake_labels)
    masked_train_loader = DataLoader(masked_train, batch_size=4, shuffle=False)
    logits = []
    for masked_batch in masked_train_loader:
        masked_input_ids = masked_batch['input_ids'].to(device)
        masked_attention_mask = masked_batch['attention_mask'].to(device)
        masked_token_type_ids = masked_batch['token_type_ids'].to(device)
        topk_index = masked_batch['topk_indices'].to(device)
        masked_input_embeds = mlm_model.bert.embeddings.word_embeddings(masked_input_ids)
        for mi_idx, topk_idx in zip(range(masked_input_embeds.size(0)), topk_index):
            masked_input_embeds[mi_idx][topk_idx + 1] = dropout(masked_input_embeds[mi_idx][topk_idx + 1])
            if masked_token_type_ids[mi_idx][topk_idx + 1] == 0:
                masked_attention_mask[mi_idx] = ~masked_token_type_ids[mi_idx] & masked_attention_mask[mi_idx]
            else:
                masked_attention_mask[mi_idx] = masked_token_type_ids[mi_idx] & masked_attention_mask[mi_idx]
                masked_attention_mask[mi_idx][0] = 1
        with torch.no_grad():
            outputs = mlm_model(attention_mask=masked_attention_mask, inputs_embeds=masked_input_embeds)
            predictions = outputs[0]
            #logits.append(predictions.detach().cpu())
        
        topk_logits = torch.topk(predictions, TOPK_NUM, dim=-1)[1]
        mask_candidates = [topk_logit[topk_idx + 1] for topk_idx, topk_logit in zip(topk_index, topk_logits)]
        
        for topk_idx, mask_candidate in zip(topk_index, mask_candidates):
            # For excepting [SEP] token ...
            if importances[topk_idx] == 0:
                continue
            recon_input_ids = input_ids.clone()
            for i, mc in enumerate(mask_candidate):
                recon_input_ids[i][topk_idx + 1] = mc
             
            with torch.no_grad():
                recon_outputs = model(recon_input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
                _, recon_prediction = torch.max(recon_outputs[0], dim=1)
                        
            # IF prediction is changed:
            if len(torch.unique(recon_prediction)) != 1:
                """
                print(tokenizer.decode(recon_input_ids[0], skip_special_tokens=True))
                print(tokenizer.decode(recon_input_ids[1], skip_special_tokens=True))
                print(tokenizer.decode(recon_input_ids[2], skip_special_tokens=True))
                print(tokenizer.decode(recon_input_ids[3], skip_special_tokens=True))
                print(recon_prediction)
                """
                causal_mask[topk_idx] = 1
                find_flag = True
                break
            
            """
            # IF prediction has entail and contradiction
            if 0 in recon_prediction and 1 in recon_prediction:
                causal_mask[topk_idx] = 1
                find_flag = True
                break
            """    
        if find_flag:
            break
    
    if 1 not in causal_mask:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
    
    return causal_mask, err_flag, 0

"""
def mask_LM_dropout_causal_words(tokens, batch, importances, topk=1):
    dropout = torch.nn.Dropout(0.5)
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    err_flag = False
    
    masked_input_ids = batch['input_ids'].squeeze().repeat((len(tokens), )).reshape(len(tokens), -1).to(device)
    masked_attention_mask = batch['attention_mask'].expand(len(tokens), -1).to(device)
    masked_token_type_ids = batch['token_type_ids'].expand(len(tokens), -1).to(device)
    fake_labels = torch.ones((len(tokens),))
    masked_train = IMDbDataset({'input_ids': masked_input_ids, 'attention_mask': masked_attention_mask, 'token_type_ids': masked_token_type_ids, 'topk_indices': topk_indices}, fake_labels)
    masked_train_loader = DataLoader(masked_train, batch_size=8, shuffle=False)
    logits = []
    for masked_batch in masked_train_loader:
        masked_input_ids = masked_batch['input_ids'].to(device)
        masked_attention_mask = masked_batch['attention_mask'].to(device)
        masked_token_type_ids = masked_batch['token_type_ids'].to(device)
        topk_index = masked_batch['topk_indices'].to(device)
        masked_input_embeds = mlm_model.bert.embeddings.word_embeddings(masked_input_ids)
        for mie, ti in zip(masked_input_embeds, topk_index):
            mie[ti + 1] = dropout(mie[ti + 1])
        with torch.no_grad():
            outputs = mlm_model(attention_mask=masked_attention_mask, token_type_ids=masked_token_type_ids, inputs_embeds=masked_input_embeds)
            predictions = outputs[0]
            logits.append(predictions)
            
    logits = torch.cat(logits, dim=0)
    topk_logits = torch.topk(logits, TOPK_NUM, dim=-1)[1]
    mask_candidates = [topk_logit[topk_idx + 1] for topk_idx, topk_logit in zip(topk_indices, topk_logits)]
    
    input_ids = batch['input_ids'].squeeze().repeat((TOPK_NUM, )).reshape(TOPK_NUM, -1).to(device)
    attention_mask = batch['attention_mask'].expand(TOPK_NUM, -1).to(device)
    token_type_ids = batch['token_type_ids'].expand(TOPK_NUM, -1).to(device)
    
    for topk_idx, mask_candidate in zip(topk_indices, mask_candidates):
        recon_input_ids = input_ids.clone()
        for i, mc in enumerate(mask_candidate):
            recon_input_ids[i][topk_idx + 1] = mc
             
        with torch.no_grad():
            recon_outputs = model(recon_input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            _, recon_prediction = torch.max(recon_outputs[0], dim=1)
        
        # IF prediction is changed:
        if len(torch.unique(recon_prediction)) != 1:
            causal_mask[topk_idx] = 1
            break
    
    if 1 not in causal_mask:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
    
    return causal_mask, err_flag, 0

def mask_LM_causal_words(tokens, batch, importances, topk=1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1]
    err_flag = False

    masked_input_ids = batch['input_ids'].squeeze().repeat((len(tokens), )).reshape(len(tokens), -1).to(device)
    masked_attention_mask = batch['attention_mask'].expand(len(tokens), -1).to(device)
    for i, topk_idx in enumerate(topk_indices):
        masked_input_ids[i][topk_idx + 1] = tokenizer.mask_token_id
    fake_labels = torch.ones((len(tokens),))
    masked_train = IMDbDataset({'input_ids': masked_input_ids, 'attention_mask': masked_attention_mask}, fake_labels)
    masked_train_loader = DataLoader(masked_train, batch_size=8, shuffle=False)
    logits = []
    for masked_batch in masked_train_loader:
        masked_input_ids = masked_batch['input_ids'].to(device)
        masked_attention_mask = masked_batch['attention_mask'].to(device)
        with torch.no_grad():
            outputs = mlm_model(masked_input_ids, attention_mask=masked_attention_mask)
            predictions = outputs[0]
            logits.append(predictions)
    logits = torch.cat(logits, dim=0)
    topk_logits = torch.topk(logits, TOPK_NUM, dim=-1)[1]
    mask_candidates = [topk_logit[topk_idx + 1] for topk_idx, topk_logit in zip(topk_indices, topk_logits)]
    
    input_ids = batch['input_ids'].squeeze().repeat((TOPK_NUM, )).reshape(TOPK_NUM, -1).to(device)
    attention_mask = batch['attention_mask'].expand(TOPK_NUM, -1).to(device)
    for topk_idx, mask_candidate in zip(topk_indices, mask_candidates):
        recon_input_ids = input_ids.clone()
        for i, mc in enumerate(mask_candidate):
            recon_input_ids[i][topk_idx + 1] = mc
        
        with torch.no_grad():
            recon_outputs = model(recon_input_ids, attention_mask=attention_mask)
            _, recon_prediction = torch.max(recon_outputs[0], dim=1)
        
        # IF prediction is changed:
        if len(torch.unique(recon_prediction)) != 1:
            causal_mask[topk_idx] = 1
            break
    
    if 1 not in causal_mask:
        causal_mask[topk_indices[0]] = 1
        err_flag = True
    
    return causal_mask, err_flag, 0
"""

def build_LM_causal_mask_with_precomputed(data_loader, all_importances, sampling_ratio, augment_ratio):
    triplets = []
    error_cnt = 0
    no_flip_cnt = 0
    no_flip_idx = []
    for importances, batch in tqdm(zip(all_importances, data_loader)):
        tokens = torch.tensor([x for x in batch['input_ids'][0][1:] if x not in [tokenizer.pad_token_id]])
        assert tokens.size() == importances.size()

        orig_sample = tokenizer.decode(tokens[:-1])
        causal_mask, err_flag, maximum_score = mask_efficient_LM_dropout_causal_words(tokens.cpu().numpy(), batch, importances.cpu().numpy(), topk=sampling_ratio)
        no_flip_idx.append(err_flag)
        if err_flag:
            no_flip_cnt += 1
        # visualize(tokens, causal_mask)
        # print(causal_mask)
        
        if 1 not in causal_mask:
            print(tokens)
            triplets.append((label, orig_sample, orig_sample, orig_sample, err_flag, maximum_score))
            continue
        
        for _ in range(augment_ratio):
            # 모든 causal 단어를 mask, 모든 non-causal 단어를 mask
            if sampling_ratio is None:
                causal_masked_tokens = [tokens[i] if causal_mask[i] == 0 else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if causal_mask[i] == 1 else tokenizer.mask_token_id for i in range(len(tokens))]

            # sampling_ratio 갯수 (int) 만큼의 단어를 mask
            elif type(sampling_ratio) == int:
                causal_indices = np.where(np.array(causal_mask) == 1)[0]
                noncausal_indices = np.where(np.array(causal_mask) == 0)[0]

                # print(causal_indices)

                causal_mask_indices = np.random.choice(causal_indices, sampling_ratio)                    
                try:
                    noncausal_mask_indices = np.random.choice(noncausal_indices, max(1, min(sampling_ratio, len(noncausal_indices))))
                    #noncausal_mask_indices = np.random.choice(noncausal_indices, 1)
                except:
                    noncausal_mask_indices = np.random.choice(causal_indices, sampling_ratio)
                    error_cnt += 1

                causal_masked_tokens = [tokens[i] if i not in causal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if i not in noncausal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
            
            # sampling_ratio 비율 (%) 만큼의 단어를 mask
            else:
                pass
                
            ### EDIT FOR NLI: REMOVE SEP TOKEN & SET LABEL ###
            causal_masked_sample = tokenizer.decode(causal_masked_tokens[:-1])
            noncausal_masked_sample = tokenizer.decode(noncausal_masked_tokens[:-1])
            
            _, labels = torch.max(batch['labels'], dim=1)
            if labels[0] == 0: label = 'contradiction'
            elif labels[0] == 1: label = 'entailment'
            else: label = 'neutral'
            triplets.append((label, orig_sample, causal_masked_sample, noncausal_masked_sample, err_flag, maximum_score))
    print(f"Error Cnt: {error_cnt}")    
    print(f"No Flip Cnt: {no_flip_cnt}")    
    return triplets, no_flip_idx

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Compute OR Load Gradient Importance

In [19]:
if not os.path.exists(PICKLE_PATH):
    os.makedirs(PICKLE_PATH)

#all_importance = compute_importances(train_loader, get_gradient_norms)
#with open("dataset/SST-2/cf_augmented_examples/gradient_importance.pickle", 'wb') as f:
#    pickle.dump(all_importance, f)

if os.path.exists(os.path.join(PICKLE_PATH, "gradient_importance.pickle")):
    with open(os.path.join(PICKLE_PATH, "gradient_importance.pickle"), 'rb') as f:
        all_importance = pickle.load(f)
else:
    all_importance = compute_importances(train_loader, get_gradient_norms)
    with open(os.path.join(PICKLE_PATH, "gradient_importance.pickle"), 'wb') as f:
        pickle.dump(all_importance, f)

# Get Average Importance

In [20]:
averaged_all_importance = compute_average_importance(train_loader, all_importance)
with open(os.path.join(PICKLE_PATH, "gradient_averaged_importance.pickle"), 'wb') as f:
        pickle.dump(averaged_all_importance, f)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




# Generate Triplets

In [26]:
sampling_ratio = 1
augment_ratio = 1
#triplets_train, no_flip_idx_train = build_propensity_causal_mask_with_precomputed(train_loader, averaged_all_importance, sampling_ratio=sampling_ratio, augment_ratio=augment_ratio)
triplets_train, no_flip_idx_train = build_LM_causal_mask_with_precomputed(train_loader, averaged_all_importance, sampling_ratio=sampling_ratio, augment_ratio=augment_ratio)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

  import sys
  



Error Cnt: 0
No Flip Cnt: 746


In [22]:
triplets_train[5]

('entailment',
 'a woman in a red polka - dot dress sings into a mic. [SEP] a woman is wearing a dress.',
 'a [MASK] in a red polka - dot dress sings into a mic. [SEP] a woman is wearing a dress.',
 'a woman in a red polka - dot dress sings into a mic. [SEP] a [MASK] is wearing a dress.',
 False,
 0)

In [28]:
with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_LM_dropout_05_flip_sep_sampling{}_augmenting{}_train.pickle".format(sampling_ratio, augment_ratio)), "wb") as fp:
    pickle.dump(triplets_train, fp)

# Cut with Empirically  Setted Threshold

In [None]:
THRESHOLD = 0.0001
threshold_cnt = 0
threshold_triplets_train = []
for t in triplets_train:
    t = list(t)
    if t[-1] < THRESHOLD:
        t[-2] = True
        threshold_cnt += 1
    t = tuple(t)
    threshold_triplets_train.append(t)
print(threshold_cnt)

In [None]:
with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_propensity_TVD_uniform_thres_00001_sampling{}_augmenting{}_train.pickle".format(sampling_ratio, augment_ratio)), "wb") as fp:
    pickle.dump(threshold_triplets_train, fp)

In [None]:
###For Generating Optionally

#for sampling_ratio in [1, 2, 3, 4, 5]:
for sampling_ratio in [1]:
    for augment_ratio in [1]:
        triplets_train, no_flip_idx_train = build_propensity_causal_mask_with_precomputed(train_loader, averaged_all_importance, sampling_ratio=sampling_ratio, augment_ratio=augment_ratio)
        with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_propensity_sampling{}_augmenting{}_train.pickle".format(sampling_ratio, augment_ratio)), "wb") as fp:
            pickle.dump(triplets_train, fp)

In [None]:
MIN_FLIPPED = 0.007849693298339844 #

In [None]:
c = sorted(unflipped_TVD_logit)
print(min(flipped_TVD_logit))
print(c[int(len(c) * 0.90)])

# Qualitative Test

In [None]:
#SHELL FOR QUALITATIVE: WILL BE REMOVED\
sampling_ratio = 1
augment_ratio = 1
triplets_train, no_flip_idx_train = build_propensity_causal_mask_with_precomputed(train_loader, averaged_all_importance, sampling_ratio=None, augment_ratio=augment_ratio)

In [None]:
qual_triplets = [tt for tt in triplets_train if tt[2].count("[MASK]") > 1]

In [None]:
qual_triplets[37]

In [None]:
for i, tt in enumerate(triplets_train):
    if tt[1] == 'commercialism all in the same movie... without neglecting character development for even one minute':
        print(i)

In [None]:
with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_propensity_sampling1_augmenting1_train.pickle".format(sampling_ratio, augment_ratio)), "rb") as fp:
            data_b = pickle.load(fp)
        
with open(os.path.join(PICKLE_PATH, "triplets_automated_averaged_gradient_sampling1_augmenting1_train.pickle".format(sampling_ratio, augment_ratio)), "rb") as fp:
            data_a = pickle.load(fp)

In [None]:
cnt = 0
diff_idxes = []
for i, (a, b) in enumerate(zip(data_a, data_b)):
    if a[2] != b[2]:
        diff_idxes.append(i)
        cnt += 1

In [None]:
tmp_i = 3
print("[ORIG]:    " + data_a[diff_idxes[tmp_i]][1])
print()
print("[GRAD]:    " + data_a[diff_idxes[tmp_i]][2])
print()
print("[PROP]:    " + data_b[diff_idxes[tmp_i]][2])

In [None]:
torch.set_printoptions(precision=4,sci_mode=False)
print(tokenizer.tokenize('commercialism all in the same movie... without neglecting character development for even one minute'))
print(averaged_all_importance[1590])

In [None]:
data_a[1590]

In [None]:
9719 / - (53349 - 67349)