In [None]:
import warnings
# Silence every warning for the rest of the session (equivalent to
# simplefilter("ignore")). NOTE(review): this also hides any warnings raised by
# Opacus / PEFT / transformers — consider narrowing by category or module.
warnings.filterwarnings("ignore")

In [None]:
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model

import opacus
from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus import PrivacyEngine

import torch
import torch.nn as nn
import numpy as np

from tqdm.notebook import tqdm
from torch.optim import SGD
from torch.utils.data import DataLoader

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

from sklearn.metrics import accuracy_score

In [None]:
# Run configuration: checkpoint and optimisation hyper-parameters.
model_name = "prajjwal1/bert-tiny"  # Hugging Face checkpoint for model and tokenizer
EPOCHS = 22
BATCH_SIZE = 256  # logical batch size handed to the (DP) training loader
LR = 1e-4

In [None]:
# Prepare data: load the GLUE SST-2 task and read the label count from the
# train split's "label" feature.
dataset = load_dataset(path="glue", name="sst2")
num_labels = dataset["train"].features["label"].num_classes

In [None]:
# Tokenizer matching the checkpoint in `model_name`.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# Tokenize every split. With batched=True the mapper receives a batch of rows,
# so "sentence" is a list of strings; each is padded/truncated to 128 tokens.
tokenized_dataset = dataset.map(
    lambda batch: tokenizer(
        batch["sentence"],
        max_length=128,
        padding="max_length",
        truncation=True,
    ),
    batched=True,
)

In [None]:
# Drop the row index and rename "label" -> "labels" (the keyword the model's
# forward accepts), then expose only the model inputs and labels as torch tensors.
tokenized_dataset = (
    tokenized_dataset
    .remove_columns(["idx"])
    .rename_column("label", "labels")
)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [None]:
# Show the processed dataset as the cell output to inspect its splits and columns.
tokenized_dataset

In [None]:
# One loader per split with identical settings. The training loader is replaced
# by make_private_with_epsilon further down. NOTE(review): Opacus samples the
# private loader itself, so shuffle=False presumably does not fix the training
# order — confirm for the installed Opacus version.
train_dataloader, test_dataloader = (
    DataLoader(tokenized_dataset[split], shuffle=False, batch_size=BATCH_SIZE)
    for split in ("train", "validation")
)

In [None]:
# Differential-privacy budget: training is calibrated to end at (EPSILON, DELTA)-DP.
EPSILON = 8.0
# δ must be tied to the number of training *records* N (at most ~1/N).
# len(train_dataloader) is the number of batches, which would make δ
# BATCH_SIZE times too loose, so use the size of the underlying dataset.
DELTA = 1 / len(train_dataloader.dataset)
# Per-sample gradient clipping norm C used by DP-SGD.
MAX_GRAD_NORM = 0.4
# Each logical batch is processed in physical chunks of at most this many
# samples to bound memory; integer division, and never below 1.
MAX_PHYSICAL_BATCH_SIZE = max(1, BATCH_SIZE // 4)

In [None]:
# Build the sequence classifier with `num_labels` outputs.
config = AutoConfig.from_pretrained(model_name)
config.num_labels = num_labels

model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)

# LoRA adapters on the modules named query/key/value; set peft_config = None to
# skip the LoRA wrapping entirely.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["query", "key", "value"],
)

if peft_config is not None:
    model = get_peft_model(model, peft_config)
    # No backward hook is registered here: register_full_backward_hook expects a
    # callable, and passing `True` would raise a TypeError if ever invoked.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assign (Module.to returns the module) rather than leaving a bare expression,
# so the cell does not print the full model repr.
model = model.to(device)

In [None]:
# Ask Opacus which modules (if any) it cannot handle; strict=False collects the
# problems into a list instead of raising.
errors = ModuleValidator.validate(module=model, strict=False)
print(errors)

In [None]:
# Plain SGD; make_private_with_epsilon wraps it into a DP optimizer below.
optimizer = SGD(model.parameters(), lr=LR)

In [None]:
# RDP accountant for tracking the privacy spent during training.
privacy_engine = PrivacyEngine(accountant="rdp")

# Pick the noise level that reaches (EPSILON, DELTA) after EPOCHS epochs, and swap
# in the DP-enabled model, optimizer and training loader.
model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    target_epsilon=EPSILON,
    target_delta=DELTA,
    epochs=EPOCHS,
    max_grad_norm=MAX_GRAD_NORM,
    batch_first=True,
)

In [None]:
# Report the calibrated noise multiplier σ, the clipping norm C, and the ε spent
# before any training step; ε uses the same .2f format as the final report.
print(f"Using Sigma = {optimizer.noise_multiplier:.3f} | C = {optimizer.max_grad_norm} | Initial DP (ε, δ) = ({privacy_engine.get_epsilon(DELTA):.2f}, {DELTA})")

In [None]:
def print_trainable_parameters(model):
    """Print the trainable vs. total parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    The trainable share is printed as a percentage with two decimals.
    """
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    all_param = sum(size for size, _ in sizes)
    trainable_params = sum(size for size, trainable in sizes if trainable)
    print(
        f"Trainable Parameters: {trainable_params} || All Parameters: {all_param} || Trainable Parameters (%): {100 * trainable_params / all_param:.2f}"
    )

# Sanity-check the LoRA setup: report how many parameters of the (now
# Opacus-wrapped) model will actually be trained.
print_trainable_parameters(model)

In [None]:
def train(model, train_dataloader, optimizer, epoch, device, logging_interval=None):
    """Run one epoch of DP-SGD training and log the loss and privacy spent.

    Each logical batch from ``train_dataloader`` is split by Opacus'
    ``BatchMemoryManager`` into physical batches of at most
    ``MAX_PHYSICAL_BATCH_SIZE`` samples, driving the same ``optimizer``.

    Args:
        model: model returned by ``make_private_with_epsilon``; called as ``model(**batch)``.
        train_dataloader: training loader returned by ``make_private_with_epsilon``.
        optimizer: optimizer returned by ``make_private_with_epsilon``.
        epoch: epoch number, used only in log messages.
        device: device each batch is moved to.
        logging_interval: if given, also print the running loss and ε every
            ``logging_interval`` physical batches. An end-of-epoch summary is
            always printed.

    Returns:
        Mean loss over the epoch's physical batches (``nan`` if there were none).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = []

    with BatchMemoryManager(
        data_loader=train_dataloader,
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
        optimizer=optimizer,
    ) as memory_safe_data_loader:
        # Take the progress total from the physical-batch loader itself instead
        # of assuming a fixed number of physical batches per logical batch.
        progress = tqdm(
            memory_safe_data_loader,
            total=len(memory_safe_data_loader),
            desc=f"Training Epoch: {epoch}",
        )
        for step, batch in enumerate(progress, start=1):
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()

            outputs = model(**batch)
            loss = criterion(outputs.logits, batch["labels"])
            loss.backward()

            optimizer.step()
            losses.append(loss.item())

            if logging_interval and step % logging_interval == 0:
                epsilon = privacy_engine.get_epsilon(DELTA)
                print(f"Training Epoch: {epoch} | Loss: {np.mean(losses):.6f} | ε = {epsilon:.2f}")

    # Summary after the whole epoch, so the loss covers every batch and ε
    # includes this epoch's steps.
    mean_loss = float(np.mean(losses)) if losses else float("nan")
    epsilon = privacy_engine.get_epsilon(DELTA)
    print(f"Training Epoch: {epoch} | Loss: {mean_loss:.6f} | ε = {epsilon:.2f}")
    return mean_loss

In [None]:
def test(model, test_dataloader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()

    losses = []
    accuracies = []

    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Test"):
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(**batch)
            loss = criterion(outputs.logits, batch["labels"])

            preds = outputs.logits.argmax(dim=-1)
            acc = accuracy_score(preds.cpu().numpy(), batch["labels"].cpu().numpy())

            losses.append(loss.item())
            accuracies.append(acc.item())

    acc = np.mean(accuracies)
    loss = np.mean(losses)

    print(
        f"Test set: Loss: {loss:.4f}, Accuracy: {acc*100:.2f}%"
    )

    return loss, acc

In [None]:
# Run EPOCHS epochs; train() receives 1-based epoch numbers for its log messages.
for epoch in tqdm(range(EPOCHS), desc=f'Training {EPOCHS} Epochs'):
    train(model, train_dataloader, optimizer, epoch=epoch + 1, device=device)

In [None]:
# Privacy actually spent over all training steps, reported at the target δ.
final_epsilon = privacy_engine.get_epsilon(delta=DELTA)
print(f"Final DP Guarantee (ε, δ)-DP = ({final_epsilon:.2f}, {DELTA})")

In [None]:
# Evaluate the DP-trained model on the SST-2 validation split; the returned
# (loss, accuracy) tuple is shown as the cell output.
test(model=model, test_dataloader=test_dataloader, device=device)