A self-supervised sentence embedding framework that improves both generalization and robustness on standard benchmarks.
# Train from bert-base-uncased on data/wiki1m_for_simcse.txt with 4 processes;
# the resulting checkpoint is written to /result/SimSCE12_bert.
# The last flag carries no trailing backslash so the command ends there.
python -m torch.distributed.launch --nproc_per_node 4 --master_port $(expr $RANDOM + 1000) train2.py \
    --model_name_or_path bert-base-uncased \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir /result/SimSCE12_bert \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model eval_avg_sts \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --do_train \
    --fp16
# Evaluate the checkpoint stored in /result/SimSCE12_bert (--do_eval only, no training).
# The last flag carries no trailing backslash so the command ends there.
python -m torch.distributed.launch --nproc_per_node 4 --master_port $(expr $RANDOM + 1000) train2.py \
    --model_name_or_path /result/SimSCE12_bert \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir /result/SimSCE12_bert \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model eval_avg_sts \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --do_eval \
    --fp16
The following code snippet fine-tunes the RobustEmbed embeddings on the IMDB classification task and then evaluates their robustness using the TextFooler adversarial attack. To switch to a different adversarial attack, uncomment the corresponding attack recipe in the code. To assess the embeddings on a different task, load another dataset (e.g., sst2 or cola).
import textattack
import random
import transformers
import datasets
from adversarial_fine_tunning import BertForAT
from datasets import load_dataset
# Fine-tune the pretrained checkpoint as a classifier on IMDB, then attack the
# fine-tuned model with a TextAttack recipe to measure adversarial robustness.

# Number of target classes for the classification head. IMDB is a binary
# sentiment task; update this value when loading a dataset with a different
# number of classes.
num_labels = 2

imdb_dataset = load_dataset('imdb')  # load a different dataset here
train_dataset = textattack.datasets.HuggingFaceDataset(imdb_dataset['train'].shuffle())
eval_dataset = textattack.datasets.HuggingFaceDataset(imdb_dataset['test'].shuffle())

# Checkpoint directory produced by the training command.
model_name = '/result/SimSCE12_bert'
config = transformers.AutoConfig.from_pretrained(
    pretrained_model_name_or_path=model_name,
    num_labels=num_labels,
)
model = BertForAT.from_pretrained(
    pretrained_model_name_or_path=model_name,
    config=config,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, do_lower_case=True)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

training_args = textattack.TrainingArgs(
    num_epochs=3,
    parallel=True,
    learning_rate=5e-5,  # 1e-5
    per_device_train_batch_size=32,
    gradient_accumulation_steps=8,
    log_to_tb=True,
)
trainer = textattack.Trainer(
    model_wrapper,
    "classification",  # regression, classification
    None,  # no attack is passed to the trainer; the attack is run afterwards
    train_dataset,
    eval_dataset,
    training_args,
)
trainer.train()

# Pick exactly one attack recipe; uncomment an alternative to switch attacks.
#attack = textattack.attack_recipes.PWWSRen2019.build(trainer.model_wrapper)
attack = textattack.attack_recipes.TextFoolerJin2019.build(trainer.model_wrapper)
#attack = textattack.attack_recipes.TextBuggerLi2018.build(trainer.model_wrapper)
#attack = textattack.attack_recipes.BAEGarg2019.build(trainer.model_wrapper)
#attack = textattack.attack_recipes.BERTAttackLi2020.build(trainer.model_wrapper)

# Attack 1000 examples from the evaluation split without per-example stdout output.
attack_args = textattack.AttackArgs(num_examples=1000, disable_stdout=True)
attacker = textattack.Attacker(attack, eval_dataset, attack_args)
attacker.attack_dataset()