In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset


In [2]:
# hyperparams

MODEL_ID = "google/gemma-3-4b-it"

epsilon = 0.01 # max perturbation
alpha = 0.002 # step size 
steps = 3 # num steps in one attack

In [3]:
# custom pgd trainer

class PGDTrainer(Trainer):

    # custom trainer to manually update input embeddings for pgd
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        embed_module = model.get_input_embeddings()

        input_ids = inputs["input_ids"]
        labels = inputs["labels"]

        # get clean embeds and allow backprop to input embeds
        input_embeds_init = embed_module(input_ids).detach()
        input_embeds_init.requires_grad = True

        # initialize noise
        delta = torch.zeros_like(input_embeds_init).uniform_(-epsilon, epsilon).to(input_embeds_init.device)
        delta.requires_grad = True

        # adversarial loop

        for _ in range(steps):

            delta.requires_grad_(True)

            perturbed_embeds = input_embeds_init + delta

            # forward pass
            outputs = model(inputs_embeds = perturbed_embeds, labels=labels) #labels = targets
            loss = outputs.loss

            # calc gradients
            loss.backward()

            # update delta embeds, don't update other weights
            with torch.no_grad():
                grad = delta.grad
                delta = (delta + alpha * grad.sign()).clamp(-epsilon, epsilon)  # move only in direction of sign, not magnitude
            
            model.zero_grad()
            delta.grad = None
            
        
        # update actual weights using perturbed embds as inputs
        final_embeds = input_embeds_init + delta.detach() # detach so grad doesn't flow back to delta
        outputs = model(inputs_embeds = final_embeds, labels = labels)

        return (outputs.loss, outputs) if return_outputs else outputs.loss

        

In [4]:
# models 

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True)
)

# lora
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules = ["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
    bias="none"
)

model = get_peft_model(model, peft_config)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

In [7]:
# get dataset

dataset = load_dataset("PKU-Alignment/BeaverTails", split="30k_train")

# filter for only safe=true, we aren't doing rlhf

def filter_safe(example):
    return example["is_safe"]

safe_dataset = dataset.filter(filter_safe)

print(f"Total rows: ${len(dataset)}")
print(f"Safe row: ${len(safe_dataset)}")


def format_prompts(batch):
    output_texts = []

    for prompt, response in zip(batch["prompt"], batch["response"]):
        # convert to gemma format
        text = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n{response}<end_of_turn>"
        output_texts.append(text)
    
    # convert to tokens, form inputs and targets for self supervised traing
    encodings = tokenizer(output_texts, truncation=True, padding="max_length", max_length=256)
    encodings["labels"] = encodings["input_ids"].copy()
    return encodings

train_dataset = safe_dataset.map(format_prompts, batched=True, remove_columns=dataset.column_names)



Total rows: $27186
Safe row: $11604


Map:   0%|          | 0/11604 [00:00<?, ? examples/s]

In [8]:
# training code

args = TrainingArguments(
    output_dir = "./gemma3-gpd-robust", # save checkpoints, logs
    per_device_train_batch_size=6,  # Increased from 4 - A100 can handle this with 8-bit + LoRA
    # Note: With PGD (3 attack steps), each training step does ~3x forward/backward passes
    # Start conservative, can increase to 8-12 if no OOM errors
    gradient_accumulation_steps=3,  # Effective batch size = 18 (slightly larger than before)
    max_steps=100,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_strategy="no",
    remove_unused_columns=False, #important so transformers doesn't drop input_embeds column
    dataloader_pin_memory=True,  # Faster data loading on GPU
    dataloader_num_workers=4,  # Parallel data loading
    # Note: gradient_checkpointing may conflict with custom PGD trainer, test first if needed
) 

trainer = PGDTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset
)

trainer.train()

model.save_pretrained("./final_robust_adapter") # final adapter weights

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Step,Training Loss
10,0.0
20,0.0
30,0.0
40,0.0
50,0.0
60,0.0
70,0.0
80,0.0
90,0.0
100,0.0
