In [1]:
! pip install opacus



In [2]:
import warnings
warnings.simplefilter("ignore")
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from opacus.utils.batch_memory_manager import BatchMemoryManager
from opacus import PrivacyEngine
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, random_split
from opacus.validators import ModuleValidator
from tqdm import tqdm


In [6]:
### Dataset preparations
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD_DEV = (0.2023, 0.1994, 0.2010)


# CIFAR-10 dataset loading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD_DEV),
])

train_dataset = datasets.CIFAR10(root='./data/cifar10', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data/cifar10', train=False, download=True, transform=transform)
#train_size = int(0.8 * len(train_dataset))
#val_size = len(train_dataset) - train_size
#train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

### Hyperparameters
MAX_GRAD_NORM = 1.2
EPSILON = 50.0
DELTA = 1e-5
EPOCHS = 20
LR = 1e-3
BATCH_SIZE = 512
MAX_PHYSICAL_BATCH_SIZE = 32

Files already downloaded and verified
Files already downloaded and verified


In [7]:
def accuracy(preds, labels):
    return (preds == labels).mean()

In [9]:

##############################################################################################
# Image classifier WITH differential privacy: Model, Optimizer, DataLoader
##############################################################################################

train_loader_dp = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
#val_loader_dp = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader_dp = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

model_dp = models.resnet18(num_classes=10)
model_dp = ModuleValidator.fix(model_dp)
ModuleValidator.validate(model_dp, strict=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_dp = model_dp.to(device)

optimizer_dp = optim.SGD(model_dp.parameters(), lr=LR)


print(
    f"Before make_private(). "
    f"Model:{type(model_dp)}, Optimizer:{type(optimizer_dp)}, DataLoader:{type(train_loader_dp)}"
)

# Now we initialize a privacy engine and attach it
privacy_engine = PrivacyEngine()
model_dp, optimizer_dp, train_loader_dp = privacy_engine.make_private_with_epsilon(
    module=model_dp,
    optimizer=optimizer_dp,
    data_loader=train_loader_dp,
    epochs=EPOCHS,
    target_epsilon=EPSILON,
    target_delta=DELTA,
    max_grad_norm=MAX_GRAD_NORM,
)
print(f"Using sigma={optimizer_dp.noise_multiplier} and C={MAX_GRAD_NORM}")

print("="*20)

print(
    f"After make_private(). "
    f"Model:{type(model_dp)}, Optimizer:{type(optimizer_dp)}, DataLoader:{type(train_loader_dp)}"
)
# Now it's business as usual

Before make_private(). Model:<class 'torchvision.models.resnet.ResNet'>, Optimizer:<class 'torch.optim.sgd.SGD'>, DataLoader:<class 'torch.utils.data.dataloader.DataLoader'>
Using sigma=0.37197113037109375 and C=1.2
After make_private(). Model:<class 'opacus.grad_sample.grad_sample_module.GradSampleModule'>, Optimizer:<class 'opacus.optimizers.optimizer.DPOptimizer'>, DataLoader:<class 'opacus.data_loader.DPDataLoader'>


In [10]:
##############################################################################################
# Image classifier WITH differential privacy: Train and test method
##############################################################################################

def train_dp(model_dp, train_loader_dp, optimizer_dp, epoch, device):
    model_dp.train()
    criterion_dp = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []
        
    with BatchMemoryManager(
        data_loader=train_loader_dp, 
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE, 
        optimizer=optimizer_dp
    ) as memory_safe_data_loader:
        for i, (images, target) in enumerate(memory_safe_data_loader):
            optimizer_dp.zero_grad()
            images = images.to(device)
            target = target.to(device)
            
            output = model_dp(images)
            loss = criterion_dp(output, target)

            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()

            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

            loss.backward()
            optimizer_dp.step()

            if (i + 1) % 200 == 0:
                epsilon = privacy_engine.get_epsilon(DELTA)
                print(
                    f"\tTrain Epoch: {epoch} \t"
                    f"Loss: {np.mean(losses):.6f} "
                    f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
                    f"(ε = {epsilon:.2f}, δ = {DELTA})"
                )


def test_dp(model_dp, test_loader_dp, device):
    model_dp.eval()
    criterion_dp = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in test_loader_dp:
            images = images.to(device)
            target = target.to(device)

            output = model_dp(images)
            loss = criterion_dp(output, target)

            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()

            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

    top1_avg = np.mean(top1_acc)


    print(
        f"\tTest set:"
        f"Loss: {np.mean(losses):.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return np.mean(top1_acc)


for epoch in tqdm(range(EPOCHS), desc="Epoch", unit="epoch"):
    train_dp(model_dp, train_loader_dp, optimizer_dp, epoch + 1, device)

top1_acc = test_dp(model_dp, test_loader_dp, device)

Epoch:   0%|          | 0/20 [00:00<?, ?epoch/s]

	Train Epoch: 1 	Loss: 2.565158 Acc@1: 9.443884 (ε = 9.14, δ = 1e-05)
	Train Epoch: 1 	Loss: 2.556096 Acc@1: 9.229973 (ε = 10.64, δ = 1e-05)
	Train Epoch: 1 	Loss: 2.543147 Acc@1: 9.130354 (ε = 11.65, δ = 1e-05)


In [21]:

##############################################################################################
# Image classifier WITHOUT differential privacy: Model, Optimizer, DataLoader
##############################################################################################

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

model = models.resnet18(num_classes=10)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=LR)

In [27]:
##############################################################################################
# Image classifier WITHOUT differential privacy: Train and test method
##############################################################################################

def train(model, train_loader, optimizer, epoch, device):
    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.to(device)
        target = target.to(device)
        
        output = model(images)
        loss = criterion(output, target)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()

        acc = accuracy(preds, labels)

        losses.append(loss.item())
        top1_acc.append(acc)

        loss.backward()
        optimizer.step()

        print("wtf: ", i)
        if (i + 1) % 200 == 0:
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
            )


def test(model, test_loader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    losses = []
    top1_acc = []

    with torch.no_grad():
        for images, target in test_loader:
            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)

            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
            labels = target.detach().cpu().numpy()

            acc = accuracy(preds, labels)

            losses.append(loss.item())
            top1_acc.append(acc)

    top1_avg = np.mean(top1_acc)


    print(
        f"\tTest set: "
        f"Loss: {np.mean(losses):.6f} "
        f"Acc: {top1_avg * 100:.6f} "
    )
    return np.mean(top1_acc)


for epoch in tqdm(range(EPOCHS), desc="Epoch", unit="epoch"):
    train(model, train_loader, optimizer, epoch + 1, device)

top1_acc = test(model, test_loader, device)

Epoch:   0%|          | 0/20 [00:00<?, ?epoch/s]

wtf:  0
wtf:  1
wtf:  2
wtf:  3
wtf:  4
wtf:  5
wtf:  6
wtf:  7
wtf:  8
wtf:  9
wtf:  10
wtf:  11
wtf:  12
wtf:  13
wtf:  14
wtf:  15
wtf:  16
wtf:  17
wtf:  18
wtf:  19
wtf:  20
wtf:  21
wtf:  22
wtf:  23
wtf:  24
wtf:  25
wtf:  26
wtf:  27
wtf:  28
wtf:  29
wtf:  30
wtf:  31
wtf:  32
wtf:  33
wtf:  34
wtf:  35
wtf:  36
wtf:  37
wtf:  38
wtf:  39
wtf:  40
wtf:  41
wtf:  42
wtf:  43
wtf:  44
wtf:  45
wtf:  46
wtf:  47
wtf:  48
wtf:  49
wtf:  50
wtf:  51
wtf:  52
wtf:  53
wtf:  54
wtf:  55
wtf:  56
wtf:  57
wtf:  58
wtf:  59
wtf:  60
wtf:  61
wtf:  62
wtf:  63
wtf:  64
wtf:  65
wtf:  66
wtf:  67
wtf:  68
wtf:  69
wtf:  70
wtf:  71
wtf:  72
wtf:  73
wtf:  74
wtf:  75
wtf:  76
wtf:  77
wtf:  78
wtf:  79
wtf:  80
wtf:  81
wtf:  82
wtf:  83
wtf:  84
wtf:  85
wtf:  86
wtf:  87
wtf:  88
wtf:  89
wtf:  90
wtf:  91
wtf:  92
wtf:  93
wtf:  94
wtf:  95
wtf:  96
wtf:  97
wtf:  98
wtf:  99
wtf:  100
wtf:  101
wtf:  102
wtf:  103
wtf:  104
wtf:  105
wtf:  106
wtf:  107
wtf:  108
wtf:  109
wtf:  110


Epoch:   5%|▌         | 1/20 [00:07<02:19,  7.36s/epoch]

wtf:  153
wtf:  154
wtf:  155
wtf:  156
wtf:  0
wtf:  1
wtf:  2
wtf:  3
wtf:  4
wtf:  5
wtf:  6
wtf:  7
wtf:  8
wtf:  9
wtf:  10
wtf:  11
wtf:  12
wtf:  13
wtf:  14
wtf:  15
wtf:  16
wtf:  17
wtf:  18
wtf:  19
wtf:  20
wtf:  21
wtf:  22
wtf:  23
wtf:  24
wtf:  25
wtf:  26
wtf:  27
wtf:  28
wtf:  29
wtf:  30
wtf:  31
wtf:  32
wtf:  33
wtf:  34
wtf:  35
wtf:  36
wtf:  37
wtf:  38
wtf:  39
wtf:  40
wtf:  41
wtf:  42
wtf:  43
wtf:  44
wtf:  45
wtf:  46
wtf:  47
wtf:  48
wtf:  49
wtf:  50
wtf:  51
wtf:  52
wtf:  53
wtf:  54
wtf:  55
wtf:  56
wtf:  57
wtf:  58
wtf:  59
wtf:  60
wtf:  61
wtf:  62
wtf:  63
wtf:  64
wtf:  65
wtf:  66
wtf:  67
wtf:  68
wtf:  69
wtf:  70
wtf:  71
wtf:  72
wtf:  73
wtf:  74
wtf:  75
wtf:  76
wtf:  77
wtf:  78
wtf:  79
wtf:  80
wtf:  81
wtf:  82
wtf:  83
wtf:  84
wtf:  85
wtf:  86
wtf:  87
wtf:  88
wtf:  89
wtf:  90
wtf:  91
wtf:  92
wtf:  93
wtf:  94
wtf:  95
wtf:  96
wtf:  97
wtf:  98
wtf:  99
wtf:  100
wtf:  101
wtf:  102
wtf:  103
wtf:  104
wtf:  105
wtf:  106


Epoch:  10%|█         | 2/20 [00:14<02:11,  7.33s/epoch]

wtf:  155
wtf:  156
wtf:  0
wtf:  1
wtf:  2
wtf:  3
wtf:  4
wtf:  5
wtf:  6
wtf:  7
wtf:  8
wtf:  9


Epoch:  10%|█         | 2/20 [00:15<02:19,  7.75s/epoch]

wtf:  10
wtf:  11
wtf:  12
wtf:  13





KeyboardInterrupt: 