# Private Fine-Tuning of VaultGemma with LoRA and Differential Privacy

This notebook demonstrates how to fine-tune Google's VaultGemma 1B model on medical data using:
- **LoRA (Low-Rank Adaptation)**: Efficient parameter-efficient fine-tuning
- **4-bit Quantization**: Reduced memory footprint using BitsAndBytes
- **Differential Privacy**: Privacy-preserving training with Opacus

The goal is to create a medical Q&A model while maintaining strong privacy guarantees.
%% [markdown]
## 1. Import Libraries and Load Dataset

We start by importing all necessary libraries:
- `transformers`: For model and tokenizer
- `peft`: For LoRA adapters
- `opacus`: For differential privacy
- `datasets`: For loading and processing the medical dataset

The dataset used is **Medical Meadow Medical Flashcards**, which contains medical question-answer pairs.


In [None]:
# 1. Install necessary libraries
!pip install -q -U peft accelerate bitsandbytes datasets pandas
!pip install git+https://github.com/huggingface/transformers@v4.56.1-Vault-Gemma-preview
!pip install kagglehub ipywidgets opacus -q

'\n!pip install git+https://github.com/huggingface/transformers@v4.56.1-Vault-Gemma-preview\n! pip install kagglehub\n! pip install ipywidgets\n! pip install protobuf -q\n! pip install tiktoken -q\n! pip install blobfile -q\n! pip install sentencepiece -q\n!pip install -q opacus\n'

In [2]:
import os
import math
import torch
import pandas as pd
import kagglehub
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    GemmaTokenizer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    get_cosine_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model
from opacus import PrivacyEngine
from opacus.validators import ModuleValidator
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from peft import PeftModel

# Load medical dataset
medical_data = load_dataset("medalpaca/medical_meadow_medical_flashcards", split="train")
NUM_SAMPLES = 1000
data = medical_data.to_pandas().head(NUM_SAMPLES)
print(data.iloc[0])

input          What is the relationship between very low Mg2+...
output         Very low Mg2+ levels correspond to low PTH lev...
instruction                      Answer this question truthfully
Name: 0, dtype: object


## 2. Load Base Model with 4-bit Quantization

We load **VaultGemma 1B** from Kaggle with 4-bit quantization to reduce memory usage:
- **NF4 quantization**: Normal Float 4-bit quantization for optimal quality
- **Double quantization**: Further compression by quantizing the quantization constants
- **bfloat16 compute**: Uses brain floating point for stable training

The model is automatically distributed across available GPUs using `device_map="auto"`.

In [3]:
# Download model from Kaggle
model_path = kagglehub.model_download("google/vaultgemma/transformers/1b")

# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Load tokenizer
tokenizer = GemmaTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

## 3. Apply LoRA Adapters

**LoRA (Low-Rank Adaptation)** adds small trainable matrices to specific layers while keeping the base model frozen:
- **r=8**: Rank of the low-rank matrices (higher = more capacity but more parameters)
- **lora_alpha=16**: Scaling factor for LoRA weights
- **target_modules**: Which attention and MLP layers to adapt (all projection layers in Gemma)
- **lora_dropout=0.05**: Dropout for regularization

This approach trains only ~1-2% of the total parameters, making training much faster and memory-efficient.


In [4]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA adapters to the model (for new training)
peft_model = get_peft_model(model, lora_config)

print("Model and LoRA adapters loaded for training!")
peft_model.print_trainable_parameters()

# Set model to training mode
peft_model.train()

Model and LoRA adapters loaded for training!
trainable params: 6,842,368 || all params: 1,045,583,488 || trainable%: 0.6544


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): VaultGemmaForCausalLM(
      (model): VaultGemmaModel(
        (embed_tokens): Embedding(256000, 1152, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x VaultGemmaDecoderLayer(
            (self_attn): VaultGemmaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=1152, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1152, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
          

## 4. Prepare and Tokenize Dataset

We create a custom tokenization function that:
1. Formats each example as an instruction-following prompt
2. Tokenizes the full text (question + answer)
3. **Masks the prompt tokens** in the labels by setting them to -100

This ensures the model only learns to generate the **response**, not to repeat the question.
The masking prevents the loss function from penalizing the model for the input prompt.


In [5]:
def tokenize_and_mask(samples):
    """
    Tokenizes the input and output, then masks the prompt tokens in labels
    so that the model only learns to predict the response.
    """
    # Format prompts and responses
    full_prompts = [
        f"Instruction:\nAnswer this question truthfully.\n\nQuestion:\n{inp}" 
        for inp in samples["input"]
    ]
    responses = [f"\n\nResponse:\n{out}" for out in samples["output"]]
    
    # Tokenize full text (prompt + response)
    model_inputs = tokenizer(
        [p + r for p, r in zip(full_prompts, responses)],
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt"
    )
    
    # Tokenize only prompts to determine their length
    prompt_tokens = tokenizer(
        full_prompts,
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt"
    )
    
    # Create labels (copy of input_ids)
    labels = model_inputs["input_ids"].clone()
    
    # Mask prompt tokens in labels (set to -100 so they're ignored in loss calculation)
    for i in range(len(labels)):
        prompt_len = int(prompt_tokens["attention_mask"][i].sum())
        labels[i][:prompt_len] = -100
    
    model_inputs["labels"] = labels
    return model_inputs

# Convert pandas DataFrame to Dataset
dataset = Dataset.from_pandas(data)

# Apply tokenization function
tokenized_dataset = dataset.map(
    tokenize_and_mask,
    batched=True,
    remove_columns=dataset.column_names
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

## 5. Configure Training Parameters

We set up all training hyperparameters and create data loaders:
- **90/10 train/validation split** for monitoring overfitting
- **Batch size = 1** with **gradient accumulation = 8** (effective batch size of 8)
- **Learning rate = 2e-5** with cosine decay schedule
- **20 epochs** of training

The small batch size is necessary due to memory constraints from the quantized model.


In [6]:
# Split dataset into train and validation
train_size = int(0.9 * len(tokenized_dataset))
train_dataset = tokenized_dataset.select(range(train_size))
eval_dataset = tokenized_dataset.select(range(train_size, len(tokenized_dataset)))

# Training hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
num_train_epochs = 2
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
learning_rate = 2e-5
eval_steps = 100
logging_steps = 40

# Initialize optimizer
optimizer = torch.optim.AdamW(peft_model.parameters(), lr=learning_rate)

# Data collator for language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Create data loaders
train_dataloader = DataLoader(
    train_dataset, 
    batch_size=per_device_train_batch_size, 
    shuffle=True,
    collate_fn=data_collator
)
eval_dataloader = DataLoader(
    eval_dataset, 
    batch_size=per_device_train_batch_size,
    collate_fn=data_collator
)

## 6. Apply Differential Privacy with Opacus

**Differential Privacy (DP)** ensures that the trained model doesn't memorize specific training examples:
- **ε (epsilon) = 8.0**: Privacy budget (lower = more privacy, but potentially worse performance)
- **δ (delta) = 1e-5**: Probability of privacy breach (should be < 1/dataset_size)
- **max_grad_norm = 1.0**: Clips gradients to prevent any single example from having too much influence

Opacus modifies the training loop to add calibrated noise to gradients, providing mathematical privacy guarantees.
The final epsilon value will tell us exactly how much privacy was consumed during training.


In [None]:
# Differential privacy setup
target_delta = 1e-5   # Lower values = more privacy
target_epsilon = 15.0 # Lower values = more privacy

privacy_engine = PrivacyEngine()
peft_model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
    module=peft_model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    target_epsilon=target_epsilon,
    target_delta=target_delta,
    epochs=num_train_epochs,
    max_grad_norm=1.0,
    poisson_sampling=False
)

# Validate and fix model for Opacus compatibility
if not ModuleValidator.is_valid(peft_model):
    peft_model = ModuleValidator.fix(peft_model)

peft_model.train()
peft_model.to(device)

# Learning rate scheduler with cosine decay and warmup
num_training_steps = math.ceil(len(train_dataloader) / gradient_accumulation_steps) * num_train_epochs
num_warmup_steps = 40

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

print("Using cosine learning rate schedule with warmup.")



Using cosine learning rate schedule with warmup.


## 8. Training Loop

The main training loop with the following features:
- **Gradient accumulation**: Accumulates gradients over 8 steps before updating weights
- **Automatic checkpointing**: Saves model when training loss drops below 0.06
- **Periodic validation**: Evaluates on validation set every 200 steps
- **Progress tracking**: Uses tqdm for visual progress bar

The loop will run for 2 epochs, logging metrics every 20 steps and evaluating every 200 steps.
Models are saved to the specified directory when performance thresholds are met.


In [None]:
print("Starting training loop...")
progress_bar = tqdm(range(num_training_steps))
global_step = 0

for epoch in range(num_train_epochs):
    peft_model.train()
    train_loss_accumulator = 0.0
    
    for step, batch in enumerate(train_dataloader):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = peft_model(**batch)
        loss = outputs.loss
        train_loss_accumulator += loss.item()
        
        # Backward pass
        loss.backward()
        
        # Optimizer step with gradient accumulation
        if (step + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            
            global_step += 1
            progress_bar.update(1)
            
            # Logging and checkpoint saving
            if global_step % logging_steps == 0:
                avg_train_loss = train_loss_accumulator / logging_steps
                log_message = f"Step {global_step}: Train Loss = {avg_train_loss:.4f}"
                
                # Save checkpoint if loss is below threshold
                if avg_train_loss < 0.06:
                    checkpoint_path = "./final_model"
                    
                    # Save PEFT adapters and tokenizer
                    peft_model.save_pretrained(checkpoint_path)
                    tokenizer.save_pretrained(checkpoint_path)
                    log_message += f" | Model Saved to {checkpoint_path}"
                
                # Validation evaluation
                if global_step % eval_steps == 0:
                    peft_model.eval()
                    eval_losses = []
                    
                    with torch.no_grad():
                        for eval_batch in eval_dataloader:
                            eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
                            eval_outputs = peft_model(**eval_batch)
                            eval_losses.append(eval_outputs.loss.item())
                    
                    avg_eval_loss = sum(eval_losses) / len(eval_losses)
                    log_message += f" | Validation Loss = {avg_eval_loss:.4f}"
                    peft_model.train()
                
                print(log_message)
                train_loss_accumulator = 0.0

# Final privacy budget
epsilon = privacy_engine.get_epsilon(delta=target_delta)
print(f"Final privacy cost: ε = {epsilon:.2f} for δ = {target_delta}")

Starting training loop...


  0%|          | 0/226 [00:00<?, ?it/s]



Step 40: Train Loss = 23.2601
Step 80: Train Loss = 22.7426
Step 120: Train Loss = 4.4020
Step 160: Train Loss = 22.0963
Step 200: Train Loss = 21.9594 | Validation Loss = 2.8816
Final privacy cost: ε = 22.21 for δ = 0.01


In [None]:
model_path = kagglehub.model_download("google/vaultgemma/transformers/1b")

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    dtype=torch.bfloat16,
    device_map="auto",
)

adapter_path = "./final_model"

tokenizer = GemmaTokenizer.from_pretrained(adapter_path)
tokenizer.pad_token = tokenizer.eos_token

peft_model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=False)

peft_model.eval()



PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): VaultGemmaForCausalLM(
      (model): VaultGemmaModel(
        (embed_tokens): Embedding(256000, 1152, padding_idx=0)
        (layers): ModuleList(
          (0-25): 26 x VaultGemmaDecoderLayer(
            (self_attn): VaultGemmaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=1152, out_features=1024, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=1152, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
           

In [None]:
def generate_response(question, max_new_tokens=128, temperature=0.9, top_p=0.9):
    prompt = f"Instruction:\nAnswer this question truthfully.\n\nQuestion:\n{question}\n\nResponse:\n"
    
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    inputs = {k: v.to(peft_model.device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = peft_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.4 
        )
    
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    if "Response:" in full_response:
        response = full_response.split("Response:")[-1].strip()
    else:
        response = full_response
    
    return response


question = "What is the role of insulin in the human body?"

response = generate_response(question)
print(f"\nQuestion: {question}")
print(f"Answer: {response}")



Question: What is the role of insulin in the human body?
Answer: The hormone affects several important functions including its actions on blood glucose levels which control hormones to lower or maintain it so that we can feel more adequate and are able respond better
