In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import warnings

# Silence every Python warning so library notices do not clutter cell output.
warnings.filterwarnings("ignore")

# Seed PyTorch's global RNG once, up front, so random draws made through torch
# (weight initialisation, dropout masks, torch.randn noise) repeat across a
# Restart Kernel -> Run All.
# NOTE(review): bit-exact reproducibility on GPU may additionally require
# deterministic-algorithm settings -- confirm if that matters here.
SEED = 42
torch.manual_seed(SEED)

# Prefer a CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [2]:
class SimpleModel(nn.Module):
    """Small convolutional classifier: two conv layers, one max-pool, two linear layers.

    The first convolution takes a single input channel and the final linear
    layer emits 10 raw scores (logits); no softmax is applied in ``forward``.
    ``fc1`` expects 9216 flattened features, which is what a 28x28 input
    yields (28 -> 26 -> 24 after the two unpadded 3x3 convolutions, 12 after
    2x2 pooling, times 64 channels).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two unpadded 3x3 convolutions with stride 1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # Dropout is only active in train() mode.
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # Classifier head: 64 channels * 12 * 12 spatial positions = 9216 inputs.
        self.fc1 = nn.Linear(in_features=64 * 12 * 12, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of logits."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        pooled = self.dropout1(F.max_pool2d(features, 2))
        # Keep the batch dimension, collapse everything else.
        flat = torch.flatten(pooled, 1)
        hidden = self.dropout2(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

In [3]:
# Convert each dataset image to a tensor; no further preprocessing is applied.
transform = transforms.ToTensor()

# Classification loss shared by both training cells; it is fed the model's raw outputs.
criterion = nn.CrossEntropyLoss()

# MNIST training split, downloaded into the current directory if it is not
# already there, served in shuffled mini-batches of 64.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root=".", train=True, transform=transform, download=True),
    shuffle=True,
    batch_size=64,
)

In [None]:
# Non-private reference model: plain Adam training for two passes over train_loader.
baseline_model = SimpleModel().to(device)
optimizer = optim.Adam(params=baseline_model.parameters(), lr=1e-3)

baseline_model.train()  # enable dropout while fitting
for epoch in range(2):
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = baseline_model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

# Switch dropout off so later cells see deterministic forward passes.
baseline_model.eval()
print("Baseline model trained")

In [None]:
# Differentially-private training (DP-SGD style) of a second SimpleModel.
#
# For every mini-batch:
#   1. compute the gradient of each example separately,
#   2. rescale each example's gradient so its global L2 norm is at most CLIP_NORM,
#   3. sum the clipped gradients and add Gaussian noise with standard deviation
#      NOISE_MULTIPLIER * CLIP_NORM,
#   4. divide by the batch size and hand the result to the optimizer.
#
# Clipping must be per example: clipping the already-averaged batch gradient
# does not bound the influence of any single training example, which is the
# quantity the noise is calibrated against.
#
# NOTE(review): no privacy accounting (epsilon/delta) is performed here, and
# batches come from a shuffled DataLoader -- confirm that matches the
# assumptions of whatever accountant is used to report a guarantee.
# NOTE: one backward pass is run per example, so this cell is considerably
# slower than the baseline training cell.
dp_model = SimpleModel().to(device)
optimizer = optim.Adam(dp_model.parameters(), lr=0.001)

CLIP_NORM = 1.0         # per-example L2 clipping bound C
NOISE_MULTIPLIER = 1.1  # sigma; the noise std on the summed gradient is sigma * C

dp_params = [p for p in dp_model.parameters() if p.requires_grad]

dp_model.train()
for epoch in range(2):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        batch_size = x.size(0)  # the final batch may be smaller than the rest

        # Running sum of the clipped per-example gradients, one tensor per parameter.
        clipped_sum = [torch.zeros_like(p) for p in dp_params]

        for i in range(batch_size):
            # Slicing with i:i + 1 keeps the batch dimension the model expects.
            sample_loss = criterion(dp_model(x[i:i + 1]), y[i:i + 1])
            sample_grads = torch.autograd.grad(sample_loss, dp_params)

            # Global L2 norm of this example's gradient across all parameters.
            sample_norm = torch.norm(
                torch.stack([torch.norm(g) for g in sample_grads])
            )
            # Scale down only when the norm exceeds the bound; never scale up.
            clip_coef = torch.clamp(CLIP_NORM / (sample_norm + 1e-6), max=1.0)

            for acc, g in zip(clipped_sum, sample_grads):
                acc.add_(g * clip_coef)

        # Noise is added once per batch to the sum, then the result is averaged.
        for p, acc in zip(dp_params, clipped_sum):
            noise = torch.randn_like(acc) * (NOISE_MULTIPLIER * CLIP_NORM)
            p.grad = (acc + noise) / batch_size

        optimizer.step()

dp_model.eval()
print("DP-SGD model trained")

In [None]:
def invert(model, target_digit, steps=800, lr=0.05):
    """Run a model-inversion attack: optimise an image the model assigns to ``target_digit``.

    Starting from Gaussian noise, a ``(1, 1, 28, 28)`` image is updated with
    Adam to minimise the negative log-probability of ``target_digit``. After
    every step the image is clamped back into ``[0, 1]``.

    Args:
        model: Classifier returning one row of class logits for a
            ``(1, 1, 28, 28)`` input. It is switched to ``eval()`` mode.
        target_digit: Index of the class to reconstruct.
        steps: Number of optimisation steps; ``0`` returns the clamped
            starting noise.
        lr: Adam learning rate for the image.

    Returns:
        ``(image, confidence)``: the detached image tensor and the softmax
        probability (a Python float) the model assigns to ``target_digit``
        *for that returned image*.
    """
    model.eval()

    # Create the image on the model's own device so the function does not
    # depend on a module-level ``device``; parameter-free models fall back to CPU.
    first_param = next(model.parameters(), None)
    img_device = first_param.device if first_param is not None else torch.device("cpu")

    fake_img = torch.randn(1, 1, 28, 28, device=img_device, requires_grad=True)
    opt = optim.Adam([fake_img], lr=lr)

    for _ in range(steps):
        opt.zero_grad()
        # log_softmax is the numerically stable form of log(softmax(.)): its
        # gradient does not vanish when the target probability underflows.
        log_probs = F.log_softmax(model(fake_img), dim=1)
        loss = -log_probs[0, target_digit]
        loss.backward()
        opt.step()
        # Project back onto the valid pixel range without recording the op in autograd.
        with torch.no_grad():
            fake_img.clamp_(0, 1)

    # Score the final image itself. Reusing the probabilities from inside the
    # loop would report the confidence of the image *before* the last update
    # and clamp, and would leave nothing to report when steps == 0.
    with torch.no_grad():
        fake_img.clamp_(0, 1)  # no-op after the loop; covers steps == 0
        confidence = F.softmax(model(fake_img), dim=1)[0, target_digit].item()

    return fake_img.detach(), confidence

In [None]:
# Class index that both inversion runs try to reconstruct.
target_digit = 3

# Attack each trained model with the default step count and learning rate.
img_base, conf_base = invert(model=baseline_model, target_digit=target_digit)
img_dp, conf_dp = invert(model=dp_model, target_digit=target_digit)

# Report the target-class confidence returned by each run.
print("Baseline confidence:", conf_base)
print("DP-SGD confidence:", conf_dp)

In [None]:
# Side-by-side view of the two reconstructed images, each titled with the
# confidence reported for it.
fig, axes = plt.subplots(1, 2, figsize=(6, 3))

panels = (
    (img_base, f"Baseline\nConf: {conf_base:.2f}"),
    (img_dp, f"DP-SGD\nConf: {conf_dp:.2f}"),
)

for ax, (image, title) in zip(axes, panels):
    # Move to CPU and drop the singleton dimensions so imshow gets a 2-D array.
    ax.imshow(image.cpu().squeeze(), cmap="gray")
    ax.set_title(title)
    ax.axis("off")

plt.show()