In [1]:
from deepee import (PrivacyWrapper, PrivacyWatchdog, UniformDataLoader,
                     ModelSurgeon, SurgicalProcedures)
import numpy as np
import torch
from torch import nn
from torchvision import datasets, transforms
from torchvision.transforms import functional as F
from matplotlib import pyplot as plt
import cv2
from skimage.restoration import denoise_wavelet

batch_size = 200
test_batch_size = 200
log_interval = 1000
epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
train_loader = UniformDataLoader(
    
    datasets.MNIST(
        "../data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
    ),
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data",
        train=False,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
    ),
    batch_size=test_batch_size,
    shuffle=True,
) 

In [11]:
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256, track_running_stats=False)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = torch.sigmoid(self.fc1(x))
        x = self.bn1(x)
        x = torch.sigmoid(self.fc2(x))
        x = self.fc3(x)
        return x

In [12]:
def wavelet_denoiser(model, sigma):
    model.wrapped_model.fc1.weight.grad = torch.tensor(
        denoise_wavelet(np.array(model.wrapped_model.fc1.weight.grad), sigma, rescale_sigma=True)
    )
    model.wrapped_model.fc2.weight.grad = torch.tensor(
        denoise_wavelet(np.array(model.wrapped_model.fc2.weight.grad), sigma, rescale_sigma=True)
    )
    model.wrapped_model.fc3.weight.grad = torch.tensor(
        denoise_wavelet(np.array(model.wrapped_model.fc3.weight.grad), sigma, rescale_sigma=True)
    )
       
def wavelet_denoiser_flatten(model, sigma, mode):
    size_fc1 = model.wrapped_model.fc1.weight.grad.size()
    model.wrapped_model.fc1.weight.grad = torch.tensor(
        denoise_wavelet(
            np.array(
                torch.flatten(
                    model.wrapped_model.fc1.weight.grad
                )
            ),
            sigma,
            rescale_sigma=True,
            method = "VisuShrink"
        )
    ).unflatten(0, size_fc1)
    
    size_fc2 = model.wrapped_model.fc2.weight.grad.size()
    model.wrapped_model.fc2.weight.grad = torch.tensor(
        denoise_wavelet(
            np.array(
                torch.flatten(
                    model.wrapped_model.fc2.weight.grad
                )
            ),
            sigma,
            rescale_sigma=True,
            method = "VisuShrink"
        )
    ).unflatten(0, size_fc2)
    
    size_fc3 = model.wrapped_model.fc3.weight.grad.size()
    model.wrapped_model.fc3.weight.grad = torch.tensor(
        denoise_wavelet(
            np.array(
                torch.flatten(
                    model.wrapped_model.fc3.weight.grad
                )
            ),
            sigma,
            rescale_sigma=True,
            method = "VisuShrink"
        )
    ).unflatten(0, size_fc3)
        
def plot_grad_hist(model, layer, rows, columns, i, title="Gradient histogram", y_lim=7500):
    bins = 50
    x_range = (-0.02, 0.02)
    
    fig.add_subplot(rows, columns, i)
    if layer == 1:
        plt.hist(
            np.transpose(np.array(torch.flatten(model.fc1.weight.grad))), 
            bins=bins, 
            range=x_range
        )
    elif layer == 2:
        plt.hist(
            np.transpose(np.array(torch.flatten(model.fc2.weight.grad))), 
            bins=bins, 
            range=x_range
        )
    elif layer == 3:
        plt.hist(
            np.transpose(np.array(torch.flatten(model.fc3.weight.grad))), 
            bins=bins, 
            range=x_range
        )
    plt.title(title)
    plt.ylim(0, y_lim)
    
def plot_grad_image(model, layer, rows, columns, i, title="Gradient image"):
    fig.add_subplot(rows, columns, i)
    if layer == 1:
        plt.imshow(F.to_pil_image(model.fc1.weight.grad))
    elif layer == 2:
        plt.imshow(F.to_pil_image(model.fc2.weight.grad))
    elif layer == 3:
        plt.imshow(F.to_pil_image(model.fc3.weight.grad))
    plt.title(title)
    plt.axis("off")

In [None]:
watchdog = PrivacyWatchdog(
    train_loader,
    target_epsilon=1.0,
    abort=False,
    target_delta=1e-5,
    fallback_to_rdp=False,
)
model = PrivacyWrapper(SimpleNet(), batch_size, 1.0, 1.0, watchdog=watchdog).to(
    device
)
optimizer = torch.optim.SGD(model.wrapped_model.parameters(), lr=0.1)

surgeon = ModelSurgeon(SurgicalProcedures.BN_to_GN)
model = surgeon.operate(model)

sigma = 0.002
plot_layer = 2
y_lim = 8000
rows = 6
columns = 3
i = 1
fig = plt.figure(figsize=(columns*7, rows*4))

# Train
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.CrossEntropyLoss()(output, target)
        loss.backward()
        
        if epoch == 4 and (batch_idx == 50 or batch_idx == 100 or batch_idx == 150):
            grad = torch.flatten(model.wrapped_model.fc2.weight.grad).detach().clone()
            
            model.clip_and_accumulate()
            
            clipped_grad = torch.flatten(model.wrapped_model.fc2.weight.grad).detach().clone()
            num_grad = len(clipped_grad)
            
            title = "Clipped (Epoch: {}, Batch: {})".format(epoch, batch_idx)
            plot_grad_hist(model.wrapped_model, plot_layer, rows, columns, i, title, y_lim)
            plot_grad_image(model.wrapped_model, plot_layer, rows, columns, i+3, title)

            model.noise_gradient()

            l2_clipped = sum(abs(clipped_grad**2 - torch.flatten(model.wrapped_model.fc2.weight.grad)**2)**(1/2))
            l2 = sum(abs(grad**2 - torch.flatten(model.wrapped_model.fc2.weight.grad)**2)**(1/2))
            title = "Noisy (Epoch: {}, Batch: {}, L2 clipped: {:.0f}, L2: {:.0f})".format(epoch, batch_idx, l2_clipped, l2)
            plot_grad_hist(model.wrapped_model, plot_layer, rows, columns, i+1, title, y_lim)
            plot_grad_image(model.wrapped_model, plot_layer, rows, columns, i+4, title)
            
            wavelet_denoiser(model, sigma)
            
            l2_clipped = sum(abs(clipped_grad**2 - torch.flatten(model.wrapped_model.fc2.weight.grad)**2)**(1/2))
            l2 = sum(abs(grad**2 - torch.flatten(model.wrapped_model.fc2.weight.grad)**2)**(1/2))
            title = "Wavelet (Epoch: {}, Batch: {}, L2 clipped: {:.0f}, L2: {:.0f})".format(epoch, batch_idx, l2_clipped, l2)
            plot_grad_hist(model.wrapped_model, plot_layer, rows, columns, i+2, title, y_lim)
            plot_grad_image(model.wrapped_model, plot_layer, rows, columns, i+5, title)
            
            i = i+6
        else:
            model.clip_and_accumulate()
            model.noise_gradient()
            wavelet_denoiser(model, sigma)

        optimizer.step()
        model.prepare_next_batch()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )

    # Test
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += torch.nn.CrossEntropyLoss(reduction="sum")(
                output, target
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss,
            correct,
            len(test_loader.dataset),
            100.0 * correct / len(test_loader.dataset),
        )
    )





INFO:root:Privacy spent at 100 steps: 0.18
INFO:root:Privacy spent at 200 steps: 0.27
INFO:root:Privacy spent at 300 steps: 0.34



Test set: Average loss: 0.6390, Accuracy: 8337/10000 (83%)


INFO:root:Privacy spent at 400 steps: 0.39
INFO:root:Privacy spent at 500 steps: 0.44
INFO:root:Privacy spent at 600 steps: 0.49



Test set: Average loss: 0.4098, Accuracy: 8975/10000 (90%)


INFO:root:Privacy spent at 700 steps: 0.53
INFO:root:Privacy spent at 800 steps: 0.57
INFO:root:Privacy spent at 900 steps: 0.61



Test set: Average loss: 0.3666, Accuracy: 9138/10000 (91%)


INFO:root:Privacy spent at 1000 steps: 0.65
INFO:root:Privacy spent at 1100 steps: 0.68
INFO:root:Privacy spent at 1200 steps: 0.72



Test set: Average loss: 0.3085, Accuracy: 9251/10000 (93%)


INFO:root:Privacy spent at 1300 steps: 0.75
INFO:root:Privacy spent at 1400 steps: 0.78
