> In language models, it is not possible to directly use
gradient-based methods to generate adversarial attacks because tokenization is not differentiable.
However, several works have attacked word embeddings (which can be viewed as the first latent state
in the network) and trained on these perturbations to improve robustness or generalization [Jiang et al.,
2019, Zhu et al., 2019, Liu et al., 2020, He et al., 2020, Kuang and Bharti, Li and Qiu, 2021, Sae-Lim
and Phoomvuthisarn, 2022, Pan et al., 2022, Schwinn et al., 2023, Geisler et al., 2024, Schwinn et al.,
2024].

In [1]:
!git clone https://github.com/thestephencasper/latent_adversarial_training.git

Cloning into 'latent_adversarial_training'...
remote: Enumerating objects: 48, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 48 (delta 23), reused 18 (delta 4), pack-reused 0[K
Receiving objects: 100% (48/48), 311.70 KiB | 3.21 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [None]:
!pip install -r latent_adversarial_training/requirements.txt

In [1]:
import torch
import datasets
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model
from datetime import datetime
import pickle
from tqdm import tqdm
from google.colab import userdata

from latent_adversarial_training.lat_model import LATLlamaForCausalLM
from latent_adversarial_training.lat_trainer import LATSFTTrainer, EvaluateFirstStepCallback

In [4]:
# Set your HF token in Colab Secrets with name HF_TOKEN
TOKEN=userdata.get('HF_TOKEN')

In [5]:
N_TRAIN=100
N_VAL=10

In [6]:
# https://huggingface.co/blog/llama2#how-to-prompt-llama-2
PROMPT_INDICATOR = 'Human: '
RESPONSE_INDICATOR = 'Assistant: '
_RESPONSE_INDICATOR = ' ' + RESPONSE_INDICATOR
PROMPT_PREFIX = '<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\n'
PROMPT_SUFFIX = ' [/INST] '
RESPONSE_SUFFIX = ' </s><s>[INST] '
def parse_examples(examples):
    prompt_len = len(PROMPT_INDICATOR)
    response_len = len(RESPONSE_INDICATOR)
    suffix_len = len(PROMPT_SUFFIX)
    parsed = []
    for s in tqdm(examples):
        s = s.replace('\n', '')
        turns = []
        while True:
            prompt_end = s.find(RESPONSE_INDICATOR)
            prompt = s[prompt_len: prompt_end]
            response_end = s.find(PROMPT_INDICATOR, prompt_len)
            response = s[prompt_end + response_len: response_end]
            turns.extend([prompt, response])
            s = s[response_end:]
            if response_end == -1:
                turns[-1] += s
                break
        example = PROMPT_PREFIX
        for i in range(0, len(turns), 2):
            example += turns[i] + PROMPT_SUFFIX + turns[i+1] + RESPONSE_SUFFIX
        parsed.append(example[:-suffix_len])
    return parsed


dataset_name = "Anthropic/hh-rlhf"
dataset = load_dataset(dataset_name).shuffle(seed=42)

train_data_dict = {'text': parse_examples(dataset['train']['chosen'][:N_TRAIN])}

100%|██████████| 100/100 [00:00<00:00, 85250.08it/s]


In [7]:
train_dataset = datasets.Dataset.from_dict(train_data_dict)
good_test_dataset = datasets.Dataset.from_dict({'text': parse_examples(dataset['test']['chosen'][:N_VAL])})
bad_test_dataset = datasets.Dataset.from_dict({'text': parse_examples(dataset['test']['rejected'][:N_VAL])})

100%|██████████| 10/10 [00:00<00:00, 52167.96it/s]
100%|██████████| 10/10 [00:00<00:00, 38269.20it/s]


In [8]:
base_model = 'meta-llama/Llama-2-7b-chat-hf'
model = LATLlamaForCausalLM.from_pretrained(base_model,
                                            device_map='auto',
                                            load_in_4bit=True,
                                            token=TOKEN)

tokenizer = AutoTokenizer.from_pretrained(base_model,
                                          trust_remote_code=True,
                                          token=TOKEN)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)

In [10]:
# Default values from latent_adversarial_training/lat.py
epochs=1
lr=5e-6
perturb_layer=0
epsilon=8
steps=6
norm_type='l2'
random_init=True
std_normalization=False
keep_in_eval=True
perturb_target='residuals' # Unused for embeddings AT

training_params = TrainingArguments(
    output_dir='./results',
    num_train_epochs=epochs,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    evaluation_strategy='steps',
    do_eval=True,
    eval_steps=0.125,
    learning_rate=lr,
    weight_decay=0.0006,
    max_grad_norm=0.25,
    max_steps=-1,
    warmup_ratio=0.03,
    lr_scheduler_type='constant',
)

In [11]:
trainer = LATSFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset={'good': good_test_dataset, 'bad': bad_test_dataset},
    dataset_text_field='text',
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_params,
    packing=False,
    perturb_layer=perturb_layer,
    epsilon=epsilon,
    steps=steps,
    norm_type=norm_type,
    random_init=random_init,
    std_normalization=std_normalization,
    keep_in_eval=keep_in_eval,
    perturb_target=perturb_target,
    peft_config=peft_config,
    callbacks=[EvaluateFirstStepCallback],
)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

In [12]:
run_id='at_layer0_eps8'
save=True

In [13]:
trainer.train()

results = {'good_val_losses': [], 'bad_val_losses': []}
for l in trainer.state.log_history:
    if 'eval_good_loss' in l.keys():
        results['good_val_losses'].append(l['eval_good_loss'])
    if 'eval_bad_loss' in l.keys():
        results['bad_val_losses'].append(l['eval_bad_loss'])

for k, v in results.items():
    print(f'{k}: {v}')

now = datetime.now()
date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
print('date and time:', date_time)
with open(f'results/{run_id}.pkl', 'wb') as f:
    pickle.dump(results, f)

if save:
    new_model_name = f'models/llama-2-7b-chat-hf-{run_id}'
    trainer.model.save_pretrained(new_model_name)

print('Done')

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss,Good Loss,Bad Loss
1,No log,No log,5.749645,5.725774
7,No log,No log,5.747596,5.723502
14,No log,No log,5.74561,5.72126
21,No log,No log,5.742654,5.717906
28,No log,No log,5.738832,5.713758
35,No log,No log,5.73431,5.708934
42,No log,No log,5.730773,5.704927
49,No log,No log,5.726768,5.700471


good_val_losses: [5.749644756317139, 5.74759578704834, 5.745609760284424, 5.742653846740723, 5.738832473754883, 5.734309673309326, 5.730773448944092, 5.726767539978027]
bad_val_losses: [5.725773811340332, 5.723502159118652, 5.721259593963623, 5.7179059982299805, 5.713757514953613, 5.7089338302612305, 5.70492696762085, 5.7004714012146]
date and time: 2024-07-06-12-54-29
Done
