# In [None]:
# dp_training.ipynb

from datasets import load_dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, AdamW
from torch.utils.data import DataLoader
import torch
from opacus import PrivacyEngine

# 1. Data source and tokenizer
# The tokenizer comes from the uncased DistilBERT base checkpoint; the corpus
# is requested from `datasets` under the name "ag_news". The two calls are
# independent of each other.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
dataset = load_dataset("ag_news")

def tokenize(batch):
    """Run the module-level ``tokenizer`` over the ``"text"`` entry of ``batch``.

    Used as the callback of ``dataset.map(..., batched=True)``, so
    ``batch["text"]`` is handed to the tokenizer in one call. Truncation is
    enabled and padding uses the ``"max_length"`` strategy.

    Returns:
        Whatever the tokenizer call returns for the given texts.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, truncation=True, padding="max_length")
    return encoded

# Apply `tokenize` to the dataset in batched mode, then expose the "train" and
# "test" splits in torch format.
dataset = dataset.map(tokenize, batched=True)
train_dataset, test_dataset = (
    dataset[split].with_format("torch") for split in ("train", "test")
)

# Training batches are small and shuffled; evaluation uses larger batches in
# the dataset's own order.
test_loader = DataLoader(test_dataset, batch_size=64)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)

# 2. Model
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=4)
# Hugging Face `from_pretrained` hands the model back in evaluation mode,
# while Opacus' module validation expects a module in training mode. Switch
# modes before the model is passed to the privacy engine so `make_private`
# does not reject it.
model.train()
optimizer = AdamW(model.parameters(), lr=5e-5)

# 3. Attach Opacus Privacy Engine
# `make_private` returns wrapped versions of the model, optimizer and training
# loader; all three names are rebound so the training loop uses the wrappers.
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.0,   # controls noise level
    max_grad_norm=1.0,      # gradient clipping
)

# 4. Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(2):  # fewer epochs since DP training is slower
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        # Forward only the tensors the model consumes; other batch entries
        # (e.g. the raw "text" column) stay behind.
        inputs = {k: v.to(device) for k, v in batch.items() if k in ["input_ids", "attention_mask"]}
        labels = batch["label"].to(device)
        # Passing `labels` makes the model return the loss alongside the logits.
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    # With the Opacus 1.x API used here (`PrivacyEngine().make_private`),
    # `get_epsilon` returns a single float. Unpacking it as an
    # (epsilon, alpha) pair raises TypeError after the first epoch, so only
    # epsilon is read and reported.
    epsilon = privacy_engine.get_epsilon(delta=1e-5)
    print(f"Epoch {epoch+1} | DP budget (ε): {epsilon:.2f}")

# 5. Accuracy on the test split
# Gradients are disabled and the model is put in eval mode; predictions are
# the arg-max over dimension 1 of the logits and are compared with the labels.
model.eval()
total = 0
correct = 0
with torch.no_grad():
    for batch in test_loader:
        labels = batch["label"].to(device)
        inputs = {
            name: tensor.to(device)
            for name, tensor in batch.items()
            if name in ("input_ids", "attention_mask")
        }
        outputs = model(**inputs)
        preds = outputs.logits.argmax(dim=1)
        total += labels.shape[0]
        correct += preds.eq(labels).sum().item()

print(f"Test Accuracy with DP: {100*correct/total:.2f}%")
