In [143]:
import json
import os
import pickle
import torch
from torch.utils.data import DataLoader
from pathlib import Path
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import BertTokenizer, BertForSequenceClassification, AdamW

In [12]:
DATASET_PATH = "dataset/aclImdb"
REFORMED_DATASET_PATH = "dataset/reform_aclImdb"
OUTPUT_PATH = "checkpoints/not_augmented_output_scheduling_warmup"
REPS_PATH = "reps"
if not os.path.exists(OUTPUT_PATH):
    os.makedirs(OUTPUT_PATH)

TRAIN_SPLIT = "train"
TEST_SPLIT = "test"

In [3]:
class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [5]:
# Load dataset
# with open(os.path.join(REFORMED_DATASET_PATH, "train.json")) as f:
#     train = json.load(f) 
# with open(os.path.join(REFORMED_DATASET_PATH, "valid.json")) as f:
#     val = json.load(f)

In [6]:
# train_texts = [d['text'] for d in train]
# train_labels = [d['label'] for d in train]
# val_texts = [d['text'] for d in val]
# val_labels = [d['label'] for d in val]

In [8]:
# Encode dataset
# train_encodings = tokenizer(train_texts, truncation=True, padding=True)
# val_encodings = tokenizer(val_texts, truncation=True, padding=True)

In [9]:
# make dataset class
# train_dataset = IMDbDataset(train_encodings, train_labels)
# val_dataset = IMDbDataset(val_encodings, val_labels)

In [10]:
# train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
# val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
# Define tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [13]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = BertForSequenceClassification.from_pretrained(os.path.join(OUTPUT_PATH, 'epoch_2'))
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

## Gradient-based Masking

In [87]:
# Load dataset
paired = pickle.load(open("./dataset/cf_augmented_examples/triplets_sampling4_augmenting1_train.pickle", 'rb'))

In [94]:
train_texts = [d[1] for d in paired]
train_labels = [[1., 0.] if d[0] == 'Negative' else [0., 1.] for d in paired]

# Encode dataset
train_encodings = tokenizer(train_texts, truncation=True, padding=True)

# make dataset class
train_dataset = IMDbDataset(train_encodings, train_labels)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)

In [98]:
# Compute gradient at BERT's position_embeddings (discard [cls] and [sep]/[pad])
# Only works for batch_size = 1    
def get_gradient_norms(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    # For CrossEntropy Loss
    _, labels = torch.max(labels, dim=1)

    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs[0]
    # print(loss)
    loss.backward(retain_graph=True)
    torch.cuda.empty_cache()

    importances = torch.tensor([]).to(device)
    for pos_index, token_index in zip(range(1, len(input_ids[0])), input_ids[0][1:]):
        if token_index == tokenizer.sep_token_id:
            break

        importance = torch.norm(model._modules['bert']._modules['embeddings']._modules['position_embeddings'].weight.grad[pos_index], 2).float().detach()
        importances = torch.cat((importances, importance.unsqueeze(0)), dim=-1)

    # importances_list.append(importances)
    model._modules['bert']._modules['embeddings']._modules['position_embeddings'].weight.grad = None

    # return importances_list
    return importances

In [149]:
import random
import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
        
def visualize(words, masks):
    fig, ax = plt.subplots(figsize=(len(words), 1))
    plt.rc('xtick', labelsize=16)
    heatmap = sn.heatmap([masks], xticklabels=words, yticklabels=False, square=True, \
                         linewidths=0.1, cmap='coolwarm', center=0.5, vmin=0, vmax=1)
    plt.xticks(rotation=45)
    plt.show()
    
def mask_causal_words(tokens, importances, topk=1):
    causal_mask = [0 for _ in range(len(tokens))]
    topk_indices = np.argsort(importances)[::-1][:topk]
    for topk_idx in topk_indices:
#         print(topk_idx)
#         print(tokens[topk_idx])
        causal_mask[topk_idx] = 1
    
    return causal_mask

def build_causal_mask(data_loader, sampling_ratio, augment_ratio):
    triplets = []
    for batch in tqdm(data_loader):
        importances = get_gradient_norms(batch)
        
        tokens = torch.tensor([x for x in batch['input_ids'][0][1:] if x not in [tokenizer.sep_token_id, tokenizer.pad_token_id]])
        assert tokens.size() == importances.size()
        
        orig_sample = tokenizer.decode(tokens)
        causal_mask = mask_causal_words(tokens.cpu().numpy(), importances.cpu().numpy(), topk=sampling_ratio)
        # visualize(tokens, causal_mask)
        # print(causal_mask)
        
        if 1 not in causal_mask:
            # print(orig_sample[1], cf_sample[1])
            continue
        
        for _ in range(augment_ratio):
            # 모든 causal 단어를 mask, 모든 non-causal 단어를 mask
            if sampling_ratio is None:
                causal_masked_tokens = [tokens[i] if causal_mask[i] == 0 else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if causal_mask[i] == 1 else tokenizer.mask_token_id for i in range(len(tokens))]

            # sampling_ratio 갯수 (int) 만큼의 단어를 mask
            elif type(sampling_ratio) == int:
                causal_indices = np.where(np.array(causal_mask) == 1)[0]
                noncausal_indices = np.where(np.array(causal_mask) == 0)[0]

                # print(causal_indices)

                causal_mask_indices = np.random.choice(causal_indices, sampling_ratio)                    
                noncausal_mask_indices = np.random.choice(noncausal_indices, sampling_ratio)

                causal_masked_tokens = [tokens[i] if i not in causal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
                noncausal_masked_tokens = [tokens[i] if i not in noncausal_mask_indices else tokenizer.mask_token_id for i in range(len(tokens))]
            
            # sampling_ratio 비율 (%) 만큼의 단어를 mask
            else:
                pass
                
            causal_masked_sample = tokenizer.decode(causal_masked_tokens)
            noncausal_masked_sample = tokenizer.decode(noncausal_masked_tokens)
            
            _, labels = torch.max(batch['labels'], dim=1)
            if labels[0] == 0: label = 'Negative'
            elif labels[0] == 1: label = 'Positive'
            triplets.append((label, orig_sample, causal_masked_sample, noncausal_masked_sample))
        
    return triplets

import pickle

# sampling_ratio = 2
# augment_ratio = 1
for sampling_ratio in [1, 2, 3]:
    for augment_ratio in [1]:
        triplets_train = build_causal_mask(train_loader, sampling_ratio=sampling_ratio, augment_ratio=augment_ratio)
        with open("dataset/cf_augmented_examples/triplets_automated_gradient_sampling{}_augmenting{}_train.pickle".format(sampling_ratio, augment_ratio), "wb") as fp:
            pickle.dump(triplets_train, fp)

HBox(children=(IntProgress(value=0, max=1700), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1700), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1700), HTML(value='')))




In [146]:
triplets_train[-1]

('Positive',
 'p',
 "zp is deeply related to that youth dream represented by the hippie movement. the college debate in the beginning of the movie states the cultural situation that gives birth to that movement. the explosion that daria imagines, represents the fall of all social structures and therefore the development of all that huge transformation that society is suffering through and finally mark's death anticipates the end that a sees for the movement itself. the film will be more easily understood if we go back to that time in life. during the 60'and 70 ', young people were the driving force for the profound explorations for change. one of the more significant changes intended was to bring sexuality out of the closet, and i think the scenes in the desert do not represent an orgy but the sexual relationship that men and women in absolute freedom would perform in the hipotetic situation where there would be nobody to hide from. i watched the scene where the couples would throw san