# Training with Controller-Based Privacy Engine (No Model Wrapping)

This tutorial demonstrates how to use Opacus's controller-based privacy engine (`PrivacyEngineGradSampleController`), which provides better compatibility with transformer models and other complex architectures by **avoiding model wrapping**.

## Why Controller-Based?

The standard `PrivacyEngine` wraps your model in a `GradSampleModule`, which can cause issues with:
- **Type checking**: `isinstance()` checks fail because the model is wrapped
- **State dict compatibility**: Wrapped models have `_module.` prefixes that complicate checkpoint loading
- **Complex architectures**: Models with custom `__getattr__` logic (e.g., HuggingFace transformers)
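All three symptoms appear even without PyTorch. The sketch below uses two hypothetical classes, `ToyModel` and `Wrapper`. The wrapper keeps the real model in a `_module` attribute and forwards attribute lookups to it. This is only the general wrapping pattern in miniature, not Opacus code.

```python
class ToyModel:
    """Hypothetical stand-in for a user model."""

    def __init__(self):
        self.hidden_dim = 64
        self._params = {"fc1.weight": [0.1, 0.2], "fc1.bias": [0.0]}

    def state_dict(self):
        return dict(self._params)


class Wrapper:
    """Hypothetical wrapper: stores the real model under `_module`."""

    def __init__(self, module):
        self._module = module

    def __getattr__(self, name):
        # Only called when normal lookup fails: forward to the wrapped model
        return getattr(self._module, name)

    def state_dict(self):
        # Nested state is saved under the attribute name, hence the prefix
        return {f"_module.{k}": v for k, v in self._module.state_dict().items()}


toy = ToyModel()
wrapped = Wrapper(toy)

print(isinstance(wrapped, ToyModel))  # False: the wrapper is a different type
print(wrapped.hidden_dim)             # 64: reached only through forwarding
print(list(wrapped.state_dict()))     # ['_module.fc1.weight', '_module.fc1.bias']
```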

The controller-based approach attaches hooks directly to your model via a `GradSampleController` **without wrapping it**, keeping your model's type and structure intact.
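Here is a plain-PyTorch sketch of the hook-based idea. The helper `attach_per_sample_hooks` is hypothetical and much simpler than Opacus. It handles only `nn.Linear`. A forward hook saves each layer's input and attaches a tensor hook to the layer's output. When the output gradient arrives, that tensor hook forms per-example weight gradients as outer products. These are stored on each parameter as `grad_sample`, the attribute name standard Opacus uses. The model object itself is never replaced.

```python
import torch
import torch.nn as nn


def attach_per_sample_hooks(model):
    """Register per-sample-gradient hooks on every nn.Linear in `model`, in place.

    Hypothetical helper for illustration. Returns the hook handles so the
    hooks can be removed later.
    """
    handles = []
    for layer in model.modules():
        if not isinstance(layer, nn.Linear):
            continue

        def forward_hook(mod, inputs, output):
            activations = inputs[0].detach()  # (batch, in_features)
            if not output.requires_grad:
                return  # nothing will backpropagate through this layer

            def backward_hook(grad_output):  # (batch, out_features)
                g = grad_output.detach()
                # Per-example weight gradient: outer product g_b a_b^T
                mod.weight.grad_sample = torch.einsum("bo,bi->boi", g, activations)
                if mod.bias is not None:
                    mod.bias.grad_sample = g.clone()

            output.register_hook(backward_hook)

        handles.append(layer.register_forward_hook(forward_hook))
    return handles


torch.manual_seed(0)
toy_net = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10))
handles = attach_per_sample_hooks(toy_net)

xb = torch.randn(8, 20)
yb = torch.randint(0, 10, (8,))
# reduction="sum" so no example's gradient is scaled by 1/batch
loss = nn.functional.cross_entropy(toy_net(xb), yb, reduction="sum")
loss.backward()

first = toy_net[0]
print(type(toy_net).__name__)          # Sequential
print(first.weight.grad_sample.shape)  # torch.Size([8, 64, 20])
# Summing per-example gradients recovers the ordinary batch gradient
print(torch.allclose(first.weight.grad_sample.sum(0), first.weight.grad, atol=1e-5))

for handle in handles:
    handle.remove()  # stop collecting per-sample gradients
```

Keeping the returned handles is what makes cleanup possible. Removing them stops any further per-sample computation, and the module's type and parameter names were never changed.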

## Setup

First, let's import the necessary libraries and create a simple dataset:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from opacus.privacy_engine_gsc import PrivacyEngineGradSampleController
import warnings
warnings.simplefilter("ignore")

In [None]:
# Create a synthetic dataset
n_samples = 1000
n_features = 20
n_classes = 10

X = torch.randn(n_samples, n_features)
y = torch.randint(0, n_classes, (n_samples,))

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

## Define a Simple Model

Let's create a simple neural network classifier:

In [None]:
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = SimpleClassifier(n_features, 64, n_classes)
print(f"Model type before: {type(model).__name__}")
print(f"isinstance check before: {isinstance(model, SimpleClassifier)}")

## Standard PrivacyEngine (for comparison)

Let's first see what happens with the standard `PrivacyEngine`:

In [None]:
from opacus import PrivacyEngine

# Create a fresh model for standard approach
model_standard = SimpleClassifier(n_features, 64, n_classes)
optimizer_standard = optim.Adam(model_standard.parameters(), lr=0.001)

privacy_engine_standard = PrivacyEngine()
model_standard, optimizer_standard, dataloader_standard = privacy_engine_standard.make_private(
    module=model_standard,
    optimizer=optimizer_standard,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

print("\nStandard PrivacyEngine:")
print(f"Model type after: {type(model_standard).__name__}")
print(f"isinstance check after: {isinstance(model_standard, SimpleClassifier)}")
print(f"State dict keys (first 3): {list(model_standard.state_dict().keys())[:3]}")

Notice how the model is now wrapped in `GradSampleModule`, `isinstance` checks fail, and state dict keys have `_module.` prefixes.
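With the standard engine, you must rename these keys before loading the weights into an unwrapped model. The plain-Python sketch below shows that workaround. `strip_prefix` is a hypothetical helper. PyTorch ships a comparable in-place utility, `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`.

```python
def strip_prefix(state_dict, prefix="_module."):
    """Return a copy of `state_dict` with `prefix` removed from matching keys.

    Hypothetical helper: keys without the prefix are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


wrapped_keys = {"_module.fc1.weight": 1, "_module.fc1.bias": 2}
print(strip_prefix(wrapped_keys))  # {'fc1.weight': 1, 'fc1.bias': 2}
```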

## Controller-Based PrivacyEngine

Now let's use the controller-based approach:

In [None]:
# Create a fresh model for controller-based approach
model = SimpleClassifier(n_features, 64, n_classes)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize controller-based privacy engine
privacy_engine = PrivacyEngineGradSampleController()

model, optimizer, dataloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=dataloader,
    noise_multiplier=1.0,
    max_grad_norm=1.0,
)

print("\nController-Based PrivacyEngine:")
print(f"Model type after: {type(model).__name__}")
print(f"isinstance check after: {isinstance(model, SimpleClassifier)}")
print(f"State dict keys (first 3): {list(model.state_dict().keys())[:3]}")

Notice how the model **keeps its original type**, `isinstance` checks **still work**, and state dict keys are **clean without prefixes**!
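`make_private` also returned a new `dataloader`. By default (`poisson_sampling=True`), the standard `PrivacyEngine` swaps in a loader that draws batches by Poisson sampling. Each example joins a batch independently with probability q, so batch sizes fluctuate around the nominal `batch_size`. The controller-based API mirrors the standard one, so it likely keeps this behavior; that is an assumption. Under that assumption, the sampling rule looks like this in plain Python. `poisson_batches` is hypothetical, and taking q = `batch_size / len(dataset)` is a simplification.

```python
import random


def poisson_batches(num_examples, sample_rate, num_batches, seed=0):
    """Yield batches of indices; each index is included independently with prob. sample_rate.

    Hypothetical helper illustrating Poisson sampling, not the Opacus sampler.
    """
    rng = random.Random(seed)
    for _ in range(num_batches):
        yield [i for i in range(num_examples) if rng.random() < sample_rate]


num_examples, nominal_batch_size = 1000, 32
sample_rate = nominal_batch_size / num_examples  # 0.032
sizes = [len(batch) for batch in poisson_batches(num_examples, sample_rate, num_batches=200)]
print(min(sizes), max(sizes), sum(sizes) / len(sizes))  # sizes vary; the mean is near 32
```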

## Training Loop

The training loop is identical to standard PyTorch:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()

EPOCHS = 3
DELTA = 1e-5

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    epsilon = privacy_engine.get_epsilon(DELTA)
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{EPOCHS} | Loss: {avg_loss:.4f} | ε: {epsilon:.2f} (δ={DELTA})")

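The loop above is plain PyTorch, but `optimizer.step()` on the optimizer returned by `make_private` does extra work. In standard Opacus, the DP optimizer takes these steps before the wrapped optimizer (here Adam) runs:

- clips each example's gradient so its L2 norm across all parameters is at most `max_grad_norm`;
- sums the clipped gradients;
- adds Gaussian noise with standard deviation `noise_multiplier * max_grad_norm`;
- divides by the expected batch size.

The sketch below reproduces that aggregation with plain tensors. `privatize_gradients` is hypothetical, and per-sample gradients are passed in directly rather than collected by hooks.

```python
import torch


def privatize_gradients(grad_samples, max_grad_norm, noise_multiplier, expected_batch_size):
    """DP-SGD aggregation: clip per-example gradients, sum, add Gaussian noise, average.

    grad_samples: list of tensors shaped (batch, *param_shape), one per parameter.
    Returns one noisy averaged gradient per parameter. Hypothetical helper.
    """
    batch = grad_samples[0].shape[0]
    # L2 norm of each example's gradient across all parameters together
    per_param_norms = [g.reshape(batch, -1).norm(dim=1) for g in grad_samples]
    per_sample_norms = torch.stack(per_param_norms, dim=1).norm(dim=1)
    # Scale factors <= 1 cap every example's total norm at max_grad_norm
    clip_factor = (max_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)

    std = noise_multiplier * max_grad_norm
    result = []
    for g in grad_samples:
        clipped_sum = torch.einsum("b,b...->...", clip_factor, g)
        if std > 0:
            noise = torch.normal(0.0, std, size=clipped_sum.shape)
        else:
            noise = torch.zeros_like(clipped_sum)
        result.append((clipped_sum + noise) / expected_batch_size)
    return result


torch.manual_seed(0)
# Four examples, two parameters (a 3x2 weight and a size-3 bias), with large gradients
grad_samples = [torch.randn(4, 3, 2) * 5, torch.randn(4, 3) * 5]

clipped_only = privatize_gradients(grad_samples, max_grad_norm=1.0, noise_multiplier=0.0, expected_batch_size=4)
total_norm = torch.cat([g.flatten() for g in clipped_only]).norm()
print(total_norm.item())  # at most 1.0: an average of vectors with norm <= 1

noisy = privatize_gradients(grad_samples, max_grad_norm=1.0, noise_multiplier=1.0, expected_batch_size=4)
```

The noise scale is tied to `max_grad_norm` because clipping bounds how much any single example can change the summed gradient.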
## Using `make_private_with_epsilon`

You can also specify a target epsilon and have the privacy engine compute the appropriate noise multiplier:

In [None]:
# Create fresh instances
model2 = SimpleClassifier(n_features, 64, n_classes)
optimizer2 = optim.Adam(model2.parameters(), lr=0.001)
dataloader2 = DataLoader(dataset, batch_size=32, shuffle=True)

privacy_engine2 = PrivacyEngineGradSampleController()

model2, optimizer2, dataloader2 = privacy_engine2.make_private_with_epsilon(
    module=model2,
    optimizer=optimizer2,
    data_loader=dataloader2,
    target_epsilon=3.0,
    target_delta=1e-5,
    epochs=EPOCHS,
    max_grad_norm=1.0,
)

print("Target epsilon: 3.0")
# The computed noise multiplier is stored on the returned DP optimizer
print(f"Computed noise multiplier: {optimizer2.noise_multiplier:.3f}")

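Conceptually, `make_private_with_epsilon` inverts the privacy accountant. It searches for the smallest noise multiplier whose ε after the requested number of epochs stays at or below `target_epsilon`. The sketch below does that search by bisection. Its accountant, `toy_epsilon`, is a hypothetical stand-in. It composes the Rényi DP bound of the plain Gaussian mechanism (order α costs α/(2σ²) per step) and ignores privacy amplification by subsampling. Its noise multipliers are therefore far larger than what Opacus computes; only the search logic carries over.

```python
import math


def toy_epsilon(noise_multiplier, steps, delta):
    """Stand-in accountant: RDP of `steps` Gaussian-mechanism releases, no subsampling.

    Each step costs alpha / (2 * sigma**2) at Renyi order alpha; the total is
    converted to (epsilon, delta)-DP and minimized over a grid of orders.
    """
    orders = [1 + x / 10 for x in range(1, 100)] + list(range(11, 64))
    return min(
        steps * alpha / (2 * noise_multiplier**2) + math.log(1 / delta) / (alpha - 1)
        for alpha in orders
    )


def find_noise_multiplier(target_epsilon, steps, delta, tol=0.01):
    """Bisection for the smallest sigma with toy_epsilon(sigma) <= target_epsilon."""
    low, high = 0.0, 1.0
    while toy_epsilon(high, steps, delta) > target_epsilon:
        high *= 2  # widen the bracket until the target is met
    while high - low > tol:
        mid = (low + high) / 2
        if toy_epsilon(mid, steps, delta) > target_epsilon:
            low = mid
        else:
            high = mid
    return high  # the upper end always satisfies the target


steps = 96  # e.g. 3 epochs of 32 batches
sigma = find_noise_multiplier(target_epsilon=3.0, steps=steps, delta=1e-5)
print(f"noise multiplier ~ {sigma:.2f}, epsilon = {toy_epsilon(sigma, steps, 1e-5):.3f}")
```

Since ε decreases monotonically as the noise multiplier grows, bisection suffices. A real accountant that models subsampling would replace only `toy_epsilon`.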
## Checkpoint Saving and Loading

Checkpointing is simpler with the controller-based approach because state dict keys carry no `_module.` prefixes:

In [None]:
# Save checkpoint
privacy_engine.save_checkpoint(
    path="checkpoint.pt",
    module=model,
    optimizer=optimizer,
)
print("Checkpoint saved!")

# Load checkpoint
model_loaded = SimpleClassifier(n_features, 64, n_classes)
optimizer_loaded = optim.Adam(model_loaded.parameters(), lr=0.001)

privacy_engine_loaded = PrivacyEngineGradSampleController()
privacy_engine_loaded.load_checkpoint(
    path="checkpoint.pt",
    module=model_loaded,
    optimizer=optimizer_loaded,
)
print("Checkpoint loaded!")

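The controller-based engine leaves parameter names untouched, so the saved weights also load into a plain model that was never made private, with strict key matching. The plain-PyTorch sketch below checks that property using an in-memory buffer instead of `checkpoint.pt`. `make_model` is a hypothetical helper. The `"module_state_dict"` key is an illustrative layout, not necessarily what `save_checkpoint` writes. For contrast, it also shows a strict load rejecting the same weights under `_module.` names.

```python
import io

import torch
import torch.nn as nn


def make_model():
    return nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10))


trained = make_model()

# Save and reload through an in-memory buffer instead of a file
buffer = io.BytesIO()
torch.save({"module_state_dict": trained.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer)

# strict=True (the default) rejects missing or unexpected keys
fresh = make_model()
result = fresh.load_state_dict(checkpoint["module_state_dict"], strict=True)
print(result.missing_keys, result.unexpected_keys)  # [] []

# The same weights under "_module." names, as a wrapped model would save them
prefixed = {f"_module.{k}": v for k, v in trained.state_dict().items()}
try:
    make_model().load_state_dict(prefixed)
    prefixed_loaded = True
except RuntimeError:
    prefixed_loaded = False
print(prefixed_loaded)  # False
```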
## Example with HuggingFace Transformers

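The controller-based approach is most useful with transformer models. The cell below is commented out because it needs the `transformers` package, and `your_dataloader` stands in for your own data loader: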

In [None]:
# Uncomment to run with transformers
# from transformers import BertForSequenceClassification
# 
# bert_model = BertForSequenceClassification.from_pretrained(
#     "bert-base-uncased",
#     num_labels=2,
# )
# 
# optimizer = optim.AdamW(bert_model.parameters(), lr=5e-5)
# 
# privacy_engine = PrivacyEngineGradSampleController()
# bert_model, optimizer, dataloader = privacy_engine.make_private(
#     module=bert_model,
#     optimizer=optimizer,
#     data_loader=your_dataloader,
#     noise_multiplier=1.0,
#     max_grad_norm=1.0,
# )
# 
# # bert_model is still BertForSequenceClassification!
# assert isinstance(bert_model, BertForSequenceClassification)

## Key Differences Summary

| Feature | Standard PrivacyEngine | Controller-Based PrivacyEngine |
|---------|------------------------|--------------------------------|
| Model wrapping | Yes (`GradSampleModule`) | **No** |
| Type preservation | No | **Yes** |
| `isinstance()` works | No | **Yes** |
| State dict prefixes | `_module.` prefix | **None** |
| Attribute access | Forwarded through the wrapper | **Direct** |
| Transformer compatibility | Can have issues | **Better** |
| Requires hook cleanup | No | Yes |
| API | Standard | **Same as standard** |

## When to Use Controller-Based?

Use `PrivacyEngineGradSampleController` when:
- Working with HuggingFace transformers or other models with complex `__getattr__` logic
- You need `isinstance()` checks to work correctly
- You want clean state dicts without `_module.` prefixes
- You need direct access to model attributes

Use standard `PrivacyEngine` when:
- You have simple models without complex introspection
- You don't need the benefits above
- You prefer the more battle-tested approach

## Learn More

For more details, see:
- [Opacus main documentation](https://opacus.ai)
- Other tutorials in the `tutorials/` folder