In [1]:
import torch
from transformers import BertForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoModel, AutoTokenizer
from torch.nn import Identity
import torch.nn as nn
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer
import pickle
import dill
from textattack.models.wrappers.huggingface_model_wrapper import HuggingFaceModelWrapper
from transformers import PreTrainedModel
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.distributions.normal import Normal
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput
from torch.utils.data import Dataset
from tqdm import tqdm
from textattack import Attacker, AttackArgs
from textattack import Attacker
from textattack.attack_recipes import BAEGarg2019
from textattack.datasets import HuggingFaceDataset
from textattack.attack_results import AttackResult
from textattack.metrics.attack_metrics import (
    AttackQueries,
    AttackSuccessRate,
    WordsPerturbed,
)


2023-07-16 16:18:43.421605: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-16 16:18:43.504644: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-07-16 16:18:43.934211: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-07-16 16:18:43.934247: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] 

In [2]:
class TransformerVIB(nn.Module):
    """
    Variational Information Bottleneck (VIB) classification head.

    An MLP encoder maps every ``hidden_size``-wide input row to ``hidden_size``
    outputs. The first ``k = hidden_size // 2`` of them are used as the mean and
    the remaining ``k`` (through a shifted softplus) as the standard deviation
    of a diagonal Gaussian over the latent ``z``. A linear layer classifies
    ``z``. No loss is computed here: ``forward`` returns the Gaussian parameters
    and the log-density of ``z`` next to the logits so the caller can build its
    own regularised objective.

    Args:
        hidden_size: width of the input features; must be a positive even
            number so it can be split into ``k`` means and ``k`` deviations.
        output_size: number of classes.
        device: stored on ``self.device``. Noise is always sampled on the device
            the activations live on, so a stale value here is harmless.

    Raises:
        ValueError: if ``hidden_size`` is not a positive even number.
    """
    def __init__(self, hidden_size, output_size, device):
        super(TransformerVIB, self).__init__()
        # An odd width would give k means but k + 1 deviations and only fail
        # later, inside forward(), with an opaque shape error.
        if hidden_size <= 0 or hidden_size % 2 != 0:
            raise ValueError(
                f"hidden_size must be a positive even number, got {hidden_size}"
            )
        self.device = device
        self.description = 'Vanilla IB VAE as per the paper'
        self.hidden_size = hidden_size
        self.k = hidden_size // 2
        self.output_size = output_size
        self.train_loss = []
        self.test_loss = []

        self.encoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU()
        )
        self.classifier = nn.Linear(self.k, output_size)
        # These are cheats to make 'dill' save everything we need in one pickle:
        # the callables are reached through `self` instead of module globals.
        self.softplus = F.softplus
        self.normal = torch.normal
        self.Normal = Normal

        # Xavier initialization of every linear/conv layer, zero biases.
        # self.modules() walks encoder.0, encoder.2, classifier in that order.
        relu_gain = nn.init.calculate_gain('relu')
        for layer in self.modules():
            if isinstance(layer, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
                layer.bias.data.zero_()

    def reparametrize(self, mu, std, device):
        """
        Performs reparameterization trick z = mu + epsilon * std
        Where epsilon~N(0,1)

        A leading axis of size 1 is added to ``mu`` and ``std`` before sampling,
        so the result has shape ``(1, *mu.shape)``.

        ``device`` is accepted for backward compatibility only: the noise is
        placed on ``std``'s device, which keeps the addition valid even when the
        module was moved after construction (or never moved at all).
        """
        mu = mu.expand(1, *mu.size())
        std = std.expand(1, *std.size())
        eps = self.normal(0, 1, size=std.size()).to(std.device)
        return mu + eps * std

    def forward(self, x):
        """
        Encode ``x``, sample (training) or take the mean (eval) and classify.

        Args:
            x: input features; indexed as ``[:, :k]`` / ``[:, k:]`` after the
                encoder, so a 2-D ``(batch, hidden_size)`` tensor is expected.

        Returns:
            ``((mu, std), log_probs, logits)`` where ``mu``/``std``/``log_probs``
            have shape ``(batch, k)`` and ``logits`` keeps the leading sample
            axis: ``(1, batch, output_size)``.
        """
        z_params = self.encoder(x)
        mu = z_params[:, :self.k]
        std = self.softplus(z_params[:, self.k:] - 1, beta=1)
        if self.training:
            z = self.reparametrize(mu, std, self.device)
        else:
            # Deterministic at evaluation time: use the mean as the latent code.
            z = mu.clone().unsqueeze(0)
        n = self.Normal(mu, std)
        log_probs = n.log_prob(z.squeeze(0))  # These may be positive as this is a PDF

        logits = self.classifier(z)
        return (mu, std), log_probs, logits
   

class TransformerHybridModel(nn.Module):
    """
    Pretrained transformer body followed by a VIB classification head.

    The attribute named ``fc_name`` on ``base_model`` (its final classification
    layer, e.g. ``'classifier'`` for ``BertForSequenceClassification``) is
    replaced by an identity, so ``base_model(...).logits`` carries the features
    that used to feed that layer. Those features go to ``vib_model``.

    Note that ``base_model`` is modified in place (its head is replaced and its
    parameters are frozen); the object passed in is not copied.

    Args:
        base_model: model whose output exposes a ``.logits`` attribute.
        vib_model: head returning ``((mu, std), log_probs, logits)``.
        device: stored on ``self.device``; not otherwise used by this class.
        fc_name: attribute name of the layer to replace with ``Identity``.
        return_only_logits: if true, ``forward`` returns just the logits with
            the leading sample axis removed.
    """
    def __init__(self, base_model, vib_model, device, fc_name, return_only_logits=False):
        super(TransformerHybridModel, self).__init__()
        self.device = device
        self.base_model = base_model
        setattr(self.base_model, fc_name, torch.nn.Identity())
        self.vib_model = vib_model
        self.train_loss = []
        self.test_loss = []
        self.freeze_base()
        self.return_only_logits = return_only_logits

    def set_return_only_logits(self, bool_value):
        """Switch between returning only logits and the full VIB tuple."""
        self.return_only_logits = bool_value

    def freeze_base(self):
        """Freeze the weights of the base model (only the VIB head trains)."""
        for param in self.base_model.parameters():
            param.requires_grad = False

    def unfreeze_base(self):
        """Make the weights of the base model trainable again."""
        for param in self.base_model.parameters():
            param.requires_grad = True

    def forward(self, **kwargs):
        """
        Run the base model and the VIB head.

        Expects ``input_ids`` as a keyword argument. ``attention_mask`` and
        ``token_type_ids`` are forwarded to the base model when supplied, so
        padded positions are not attended to; callers passing only
        ``input_ids`` get exactly the previous behaviour. A VIB head should be
        trained on features extracted the same way (with or without the mask)
        as it will see at inference time.

        Returns:
            ``logits.squeeze(0)`` when ``return_only_logits`` is set, otherwise
            ``((mu, std), log_probs, logits)`` as produced by the VIB head.
        """
        optional_inputs = {
            name: kwargs[name]
            for name in ('attention_mask', 'token_type_ids')
            if kwargs.get(name) is not None
        }
        # This is not really logits, only called that way cause we've changed
        # the final layer to identity.
        encoded = self.base_model(kwargs['input_ids'], **optional_inputs).logits
        (mu, std), log_probs, logits = self.vib_model(encoded)
        if self.return_only_logits:
            return logits.squeeze(0)
        return ((mu, std), log_probs, logits)


class TransformerAdaptor(transformers.PreTrainedModel):
    """
    Adapts a TransformerHybridModel so it can be handed to a
    HuggingFaceModelWrapper: it is a ``PreTrainedModel`` (configured with the
    base model's config) whose ``forward`` returns a
    ``SequenceClassifierOutput`` holding the VIB logits.
    """
    def __init__(self, hybrid_model):
        super(TransformerAdaptor, self).__init__(hybrid_model.base_model.config)
        self.hybrid_model = hybrid_model
        # Cheat so that dill pickles the output class together with the instance.
        self.SequenceClassifierOutput = SequenceClassifierOutput

    def forward(self, **kwargs):
        """
        Forward ``kwargs`` to the hybrid model and wrap the logits.

        The hybrid model returns logits without the leading sample axis when
        its ``return_only_logits`` flag is set, and with it otherwise. The axis
        is removed only in the second case, so both modes yield the same
        per-example logits. (Indexing ``[0]`` unconditionally returned only the
        first example's logits in ``return_only_logits`` mode.)
        """
        if self.hybrid_model.return_only_logits:
            logits = self.hybrid_model(**kwargs)
        else:
            (_, _), _, logits = self.hybrid_model(**kwargs)
            logits = logits.squeeze(0)
        return self.SequenceClassifierOutput(logits=logits)


In [3]:
# Download (or load from the local Hugging Face cache) the BERT sequence
# classifier checkpoint published by TextAttack for Yelp polarity.
model = BertForSequenceClassification.from_pretrained('textattack/bert-base-uncased-yelp-polarity')

In [169]:
# Serialise the whole pretrained model object with dill (not just its weights).
# NOTE(review): absolute, machine-specific path -- the directory must already exist.
with open('/D/models/pretrained/pretrained_bert_yelp.pkl', 'wb') as f:
    dill.dump(model, f)

In [4]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
# Tokenizer shared by the TextAttack wrapper below.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# VIB head over 768-wide features with 2 output classes, attached to the
# pretrained model. Constructing the hybrid replaces `model.classifier` with
# an identity and freezes `model` in place.
vib_classifier = TransformerVIB(768, 2, device)
hybrid_model = TransformerHybridModel(
    model,
    vib_classifier,
    device,
    fc_name='classifier',
)

# hybrid -> PreTrainedModel adaptor -> TextAttack model wrapper
adaptor = TransformerAdaptor(hybrid_model)
wrapper = HuggingFaceModelWrapper(adaptor, tokenizer)

In [163]:
# Persist the hybrid model, its TextAttack wrapper and the adaptor, each to its
# own dill pickle (same files, same order).
for _pickle_path, _obj in (
    ('/D/models/bert/bert_yelp_vib.pkl', hybrid_model),
    ('/D/models/bert/bert_yelp_vib_wrapper.pkl', wrapper),
    ('/D/models/bert/bert_yelp_vib_adaptor.pkl', adaptor),
):
    with open(_pickle_path, 'wb') as f:
        dill.dump(_obj, f)

### Yelp dataset

In [3]:
# Load the Yelp Polarity dataset from the Hugging Face hub (or the local cache).
# The result is indexed by split name ('train' / 'test') further below.
dataset = load_dataset('yelp_polarity')

Reusing dataset yelp_polarity (/home/nir/.cache/huggingface/datasets/yelp_polarity/plain_text/1.0.0/14f90415c754f47cf9087eadac25823a395fef4400c7903c5897f55cfaaa6f61)


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')


def encode(examples):
    """Tokenize a batch of reviews, truncating/padding every one to 512 tokens."""
    texts = examples['text']
    return tokenizer(texts, truncation=True, padding='max_length', max_length=512)


# Tokenize all splits in batches, expose only the tensor columns the loaders
# need, and keep a copy of the processed dataset on disk.
dataset = dataset.map(encode, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
dataset.save_to_disk('/D/datasets/yelp/')

train_dataset, test_dataset = dataset['train'], dataset['test']

Loading cached processed dataset at /home/nir/.cache/huggingface/datasets/yelp_polarity/plain_text/1.0.0/14f90415c754f47cf9087eadac25823a395fef4400c7903c5897f55cfaaa6f61/cache-f43dc4f27a40156d.arrow
Loading cached processed dataset at /home/nir/.cache/huggingface/datasets/yelp_polarity/plain_text/1.0.0/14f90415c754f47cf9087eadac25823a395fef4400c7903c5897f55cfaaa6f61/cache-85110606db85e919.arrow


In [8]:
# test_dataset = load_from_disk('/D/datasets/yelp/test')

In [13]:
class LogitsDataset(Dataset):
    """Map-style dataset pairing precomputed model outputs with their labels."""

    def __init__(self, data, labels):
        # Both containers are indexed with the same integer position.
        self.data = data
        self.labels = labels

    def __len__(self):
        """Number of examples, taken from ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]

def get_transformer_logits_dataloader(model, original_loader, device, batch_size=8):
    """
    Run ``model`` once over ``original_loader`` and cache its outputs.

    Every batch is a dict with ``'input_ids'`` and ``'label'`` and, optionally,
    ``'attention_mask'``. The mask is forwarded to the model when present so
    padded positions are ignored. The model is put in eval mode for the pass
    (so dropout cannot make the cached features noisy) and its previous mode is
    restored afterwards. No gradients are tracked.

    Args:
        model: callable module whose output exposes ``.logits``.
        original_loader: iterable of batch dicts as described above.
        device: device the inputs are moved to before calling ``model``.
        batch_size: batch size of the returned loader.

    Returns:
        A shuffling ``DataLoader`` over a ``LogitsDataset`` holding the
        concatenated ``.logits`` and labels, both on the CPU.
    """
    was_training = model.training
    model.eval()
    logits_data_list = []
    logits_labels_list = []
    try:
        with torch.no_grad():
            for data_dict in tqdm(original_loader):
                optional_inputs = {}
                attention_mask = data_dict.get('attention_mask')
                if attention_mask is not None:
                    optional_inputs['attention_mask'] = attention_mask.to(device)
                output = model(data_dict['input_ids'].to(device), **optional_inputs)
                # Park results on the CPU so GPU memory does not grow per batch.
                logits_data_list.append(output.logits.cpu())
                logits_labels_list.append(data_dict['label'].cpu())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    logits_data_set = LogitsDataset(torch.cat(logits_data_list), torch.cat(logits_labels_list))
    logits_dataloader = DataLoader(logits_data_set, batch_size=batch_size, shuffle=True)

    return logits_dataloader

In [8]:
# Batch the tokenized splits, 128 examples per batch.
# NOTE(review): the test loader is shuffled as well; that does not change
# aggregate metrics, only the order in which examples are visited.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=128)
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=128)

In [37]:
# Swap the classification layer for an identity so `.logits` carries the
# features that used to feed it, then cache those features for both splits.
model.classifier = torch.nn.Identity()
# Use the first CUDA GPU when available; fall back to CPU otherwise.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(DEVICE)
# The misspelled name `logits_trian_dataloader` is kept on purpose: the cells
# that save and reload the loaders refer to it.
logits_trian_dataloader = get_transformer_logits_dataloader(model, train_dataloader, batch_size=64, device=DEVICE)
logits_test_dataloader = get_transformer_logits_dataloader(model, test_dataloader, batch_size=64, device=DEVICE)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4375/4375 [4:04:53<00:00,  3.36s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 297/297 [16:40<00:00,  3.37s/it]


In [15]:
# Where the pickled feature dataloaders are stored (absolute, machine-specific paths).
LOGITS_TRAIN_DATALOADER_PATH = '/D/datasets/yelp/logits_dataloaders/logits_train_dataloader.pkl'
LOGITS_TEST_DATALOADER_PATH = '/D/datasets/yelp/logits_dataloaders/logits_test_dataloader.pkl'

In [40]:
# Save both cached-feature dataloaders, one pickle file each.
import pickle5

for _target_path, _loader in (
    (LOGITS_TRAIN_DATALOADER_PATH, logits_trian_dataloader),
    (LOGITS_TEST_DATALOADER_PATH, logits_test_dataloader),
):
    with open(_target_path, 'wb') as f:
        pickle5.dump(_loader, f)
print('Saved dataloaders!')

Saved dataloaders!


In [16]:
# Reload the cached-feature dataloaders written by the save cell.
# NOTE: unpickling executes arbitrary code -- only load files you produced yourself.
import pickle5
with open(LOGITS_TRAIN_DATALOADER_PATH, 'rb') as f:
    logits_trian_dataloader = pickle5.load(f)
with open(LOGITS_TEST_DATALOADER_PATH, 'rb') as f:
    logits_test_dataloader = pickle5.load(f)

In [17]:
# Use the first CUDA GPU when available; fall back to CPU otherwise.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Fresh VIB head: 768-wide input features, 2 output classes.
vib_classifier = TransformerVIB(768, 2, device)

In [5]:
# Reload the unmodified pretrained classifier and move it to the compute device.
model = BertForSequenceClassification.from_pretrained('textattack/bert-base-uncased-yelp-polarity')
# Use the first CUDA GPU when available; fall back to CPU otherwise.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
_ = model.to(DEVICE)  # assigned to `_` so the cell does not echo the module repr

In [6]:
# Rebuild the test loader with a smaller batch size (32) for evaluation.
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=32)

In [7]:
def test_model(model, test_data_loader, device):
    model.eval()
    total_correct = 0
    total_incorrect = 0
    with torch.no_grad():
        for data_dict in tqdm(test_data_loader):
            x = data_dict['input_ids']
            y = data_dict['label']
            output = model(x.to(device))
            logits = output['logits']
            predictions = torch.argmax(torch.softmax(logits, dim=-1), dim=1)#.to(torch.device('cpu'))
            correct_classifications = sum(predictions == y.to(device))
            incorrect_classifications = len(x) - correct_classifications
            total_correct += correct_classifications
            total_incorrect += incorrect_classifications
    model.train()
    print(f"acc: {print(total_correct / (total_correct + total_incorrect))}")

In [8]:
# Accuracy of the pretrained classifier over the whole (tokenized) test split.
test_model(model, test_dataloader, DEVICE)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1188/1188 [15:03<00:00,  1.31it/s]

tensor(0.8969, device='cuda:0')
acc: None





In [26]:
# This file is written with dill.dump elsewhere in the notebook, so it is read
# back with dill as well rather than the plain pickle module.
# NOTE: unpickling executes arbitrary code -- only load files you produced yourself.
with open('/D/models/pretrained/pretrained_bert_yelp.pkl', 'rb') as f:
    m = dill.load(f)

In [14]:
# --- Baseline: attack the unmodified pretrained classifier -------------------
model = BertForSequenceClassification.from_pretrained('textattack/bert-base-uncased-yelp-polarity')
model.eval()
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# TextAttack's own view of the Yelp polarity test split (raw text, not the
# tokenized `dataset` built earlier -- this rebinds the name `dataset`).
dataset = HuggingFaceDataset("yelp_polarity", None, "test")

model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

# Build the BAEGarg2019 attack recipe against the wrapped model.
attack = BAEGarg2019.build(model_wrapper)
# Attack 250 examples quietly.
# NOTE(review): enable_advance_metrics presumably accounts for the extra model
# downloads seen in this cell's output -- confirm against the TextAttack docs.
attack_args = AttackArgs(
    num_examples=250,
    disable_stdout=True,
    silent=True,
    enable_advance_metrics=True
)

attacker = Attacker(attack, dataset, attack_args)
results_iterable = attacker.attack_dataset()
# The log manager keeps the per-example results used for the metrics below.
attack_log_manager = attacker.attack_log_manager
attack_log_manager.log_summary()

# Aggregate metrics over the collected attack results.
attack_success_stats = AttackSuccessRate().calculate(attack_log_manager.results)
words_perturbed_stats = WordsPerturbed().calculate(attack_log_manager.results)
attack_query_stats = AttackQueries().calculate(attack_log_manager.results)
# Headline numbers as strings; `avg_pertrubed_words_prct` keeps its original
# (misspelled) name in case cells outside this view refer to it.
acc_under_attack = str(attack_success_stats["attack_accuracy_perc"])
avg_pertrubed_words_prct = str(words_perturbed_stats["avg_word_perturbed_perc"])


Reusing dataset yelp_polarity (/home/nir/.cache/huggingface/datasets/yelp_polarity/plain_text/1.0.0/14f90415c754f47cf9087eadac25823a395fef4400c7903c5897f55cfaaa6f61)


  0%|          | 0/2 [00:00<?, ?it/s]

textattack: Loading [94mdatasets[0m dataset [94myelp_polarity[0m, split [94mtest[0m.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
textattack: Unknown if model of class <class 'transformers.models.bert.modeling_bert.BertForSequenceClassification'> compatible with goal function <class 'textattack.goal_functions.classification.untargeted_classification.UntargetedClassification'>.
[Succeeded / Failed / Skip




Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (32148 > 1024). Running this sequence through the model will result in indexing errors





Token indices sequence length is longer than the specified maximum sequence length for this model (32148 > 1024). Running this sequence through the model will result in indexing errors


In [15]:
# [label, value] rows for the attack summary table; percentages are rendered
# as strings with a trailing '%', counts are left as numbers.
summary_table_rows = [
    ["Number of successful attacks:", attack_success_stats["successful_attacks"]],
    ["Number of failed attacks:", attack_success_stats["failed_attacks"]],
    ["Number of skipped attacks:", attack_success_stats["skipped_attacks"]],
    ["Original accuracy:", f'{attack_success_stats["original_accuracy"]}%'],
    ["Accuracy under attack:", f'{attack_success_stats["attack_accuracy_perc"]}%'],
    ["Attack success rate:", f'{attack_success_stats["attack_success_rate"]}%'],
    ["Average perturbed word %:", f'{words_perturbed_stats["avg_word_perturbed_perc"]}%'],
    ["Average num. words per input:", words_perturbed_stats["avg_word_perturbed"]],
]

In [16]:
# Display the accuracy (in percent) on the attacked examples before any perturbation.
attack_success_stats["original_accuracy"]

98.4

In [17]:
# Print the summary rows one per line as plain [label, value] lists.
for row in summary_table_rows:
    print(row)

['Number of successful attacks:', 148]
['Number of failed attacks:', 98]
['Number of skipped attacks:', 4]
['Original accuracy:', '98.4%']
['Accuracy under attack:', '39.2%']
['Attack success rate:', '60.16%']
['Average perturbed word %:', '7.01%']
['Average num. words per input:', 136.34]


In [11]:
# Send the same rows through the attack's log manager, titled "Attack Results"
# and keyed "attack_results_summary".
attack_log_manager.log_summary_rows(
    summary_table_rows, "Attack Results", "attack_results_summary"
)