In [None]:
import os

import torch
import torch.nn as nn
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import login
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments

In [15]:
# hyperparams

# Base checkpoint to fine-tune.
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# PGD attack on the input embeddings
epsilon = 1e-2  # radius: largest allowed |delta| per embedding coordinate
alpha = 2e-3    # size of each signed-gradient ascent step
steps = 3       # ascent steps per training batch

In [None]:
# Hugging Face authentication
# Load the token from a .env file (HF_TOKEN or HUGGINGFACE_TOKEN).
load_dotenv()
hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=False)
else:
    # login(token=None) opens an interactive widget, which blocks "Run All".
    # Fall back to any credentials already cached on this machine; the model
    # download below is expected to fail with an auth error if there are none.
    print("HF_TOKEN / HUGGINGFACE_TOKEN not set; using cached Hugging Face credentials if any.")


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svâ€¦

In [13]:
# custom pgd trainer

class PGDTrainer(Trainer):
    """Trainer that fine-tunes on PGD-perturbed input embeddings.

    For each batch, a random perturbation ``delta`` in [-epsilon, epsilon] is added to
    the token embeddings and refined for ``steps`` signed-gradient ascent steps of size
    ``alpha`` to increase the LM loss, projecting back into [-epsilon, epsilon] after
    each step. The training loss is then the loss on the final perturbed embeddings.
    ``epsilon``, ``alpha`` and ``steps`` are the module-level hyperparameters.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the loss on PGD-perturbed embeddings (plus model outputs if requested).

        ``num_items_in_batch`` is accepted for Trainer compatibility but not used; the
        model's own loss over ``labels`` is returned.
        """
        embed_module = model.get_input_embeddings()

        input_ids = inputs["input_ids"]
        labels = inputs["labels"]
        # Forward the padding mask so padded positions are not attended to.
        attention_mask = inputs.get("attention_mask")

        # Clean embeddings are a constant during the attack; only delta is optimized.
        input_embeds_init = embed_module(input_ids).detach()

        # Random start inside the epsilon box, on the embeddings' device and dtype.
        delta = torch.empty_like(input_embeds_init).uniform_(-epsilon, epsilon)
        delta.requires_grad_(True)

        # Adversarial loop: ascend the loss with respect to delta.
        for _ in range(steps):
            outputs = model(
                inputs_embeds=input_embeds_init + delta,  # HF kwarg is `inputs_embeds`
                attention_mask=attention_mask,
                labels=labels,
            )
            # Differentiate w.r.t. delta only. loss.backward() here would also
            # accumulate the attack's gradients into the trainable (LoRA) weights.
            (delta_grad,) = torch.autograd.grad(outputs.loss, delta)

            with torch.no_grad():
                # Step along the gradient sign (not magnitude), then project back into
                # the box. In-place ops keep `delta` the same leaf with requires_grad=True.
                delta.add_(alpha * delta_grad.sign())
                delta.clamp_(-epsilon, epsilon)

        # Training loss on the adversarial inputs; delta is detached so the backward
        # pass run by Trainer only reaches the model weights.
        final_embeds = input_embeds_init + delta.detach()
        outputs = model(inputs_embeds=final_embeds, attention_mask=attention_mask, labels=labels)

        return (outputs.loss, outputs) if return_outputs else outputs.loss

        

In [None]:
# models

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Fixed-length padding needs a pad token; reuse EOS as the pad token.
tokenizer.pad_token = tokenizer.eos_token

# 8-bit weights via bitsandbytes. Passing `load_in_8bit` directly to from_pretrained
# is deprecated in favour of a BitsAndBytesConfig in `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# LoRA adapters on the attention query/value projections only.
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
    bias="none",
)

model = get_peft_model(model, peft_config)

In [None]:
# get dataset

dataset = load_dataset("PKU-Alignment/BeaverTails", split="30k_train")

# Keep only rows flagged is_safe: this is plain supervised fine-tuning on safe
# responses, not RLHF, so unsafe responses are not needed.
def filter_safe(example):
    """Return the row's ``is_safe`` flag (True keeps the row)."""
    return example["is_safe"]

safe_dataset = dataset.filter(filter_safe)

# Plain f-strings: a leading "$" would be printed literally.
print(f"Total rows: {len(dataset)}")
print(f"Safe rows: {len(safe_dataset)}")


def format_prompts(batch, max_length=256):
    """Render prompt/response pairs in the Llama 3 chat format and tokenize them.

    Args:
        batch: a batched ``datasets`` slice with parallel ``"prompt"`` and
            ``"response"`` lists.
        max_length: every example is truncated/padded to exactly this many tokens.

    Returns:
        The tokenizer encodings plus ``labels``: a copy of ``input_ids`` in which
        padding positions (attention_mask == 0) are set to -100 so the causal-LM
        loss ignores them.
    """
    # Llama 3 chat format written out by hand; it already includes <|begin_of_text|>.
    output_texts = [
        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{response}<|eot_id|>"
        for prompt, response in zip(batch["prompt"], batch["response"])
    ]

    # add_special_tokens=False: the text carries its own BOS, so don't let the
    # tokenizer prepend a second one.
    encodings = tokenizer(
        output_texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        add_special_tokens=False,
    )

    # Self-supervised targets: predict the sequence itself, but ignore padding.
    # Masking uses attention_mask rather than pad_token_id because pad_token is set
    # to eos_token, which may also occur as a real token in the text.
    encodings["labels"] = [
        [token if keep else -100 for token, keep in zip(ids, mask)]
        for ids, mask in zip(encodings["input_ids"], encodings["attention_mask"])
    ]
    return encodings


train_dataset = safe_dataset.map(format_prompts, batched=True, remove_columns=safe_dataset.column_names)

    

In [None]:
# training code
# Adversarial fine-tuning: 100 optimizer steps, each accumulating 4 micro-batches of 4.

args = TrainingArguments(
    output_dir="./llam3-gpd-robust",  # where checkpoints / logs would go
    max_steps=100,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_strategy="no",  # no mid-run checkpoints; the adapter is saved explicitly below
    # Keep every dataset column so compute_loss receives input_ids / labels as-is.
    remove_unused_columns=False,
)

trainer = PGDTrainer(model=model, args=args, train_dataset=train_dataset)
trainer.train()

# Persist the fine-tuned adapter weights.
model.save_pretrained("./final_robust_adapter")