<a href="https://colab.research.google.com/github/jusjusjus/noise-in-dpsgd-2020/blob/master/train_dpgan_in_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup Environment

In [None]:
# Clone the project repository; later cells change into this checkout and
# import the `ganlib` package from it.
!git clone https://github.com/jusjusjus/noise-in-dpsgd-2020.git

In [None]:
cd noise-in-dpsgd-2020

In [None]:
# Show which GPU (if any) is attached to this runtime.
!nvidia-smi

# Build Classifier

In [None]:
from os import makedirs
from os.path import join, dirname

import torch
import numpy as np
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from ganlib.dataset import Dataset

from ganlib.classifier import Classifier

torch.manual_seed(42 * 42)

In [None]:
def evaluate(model, dataloader):
    """Return the classification accuracy of ``model`` on ``dataloader`` in percent.

    Args:
        model: a ``torch.nn.Module`` mapping a batch of examples to per-class
            scores; the predicted class is the argmax over the last dimension.
        dataloader: iterable yielding ``(examples, labels)`` tensor pairs.

    Returns:
        The percentage of correctly classified examples as a float in
        ``[0, 100]``; ``0.0`` if ``dataloader`` yields no examples.

    The model is switched to eval mode while predictions are computed and its
    previous train/eval mode is restored afterwards, also if an error occurs.
    """
    # Take the device from the model itself instead of relying on a
    # notebook-level `device` variable, so the function works regardless of
    # which cells have been executed.  Parameter-free models are evaluated on
    # whatever device the batches already live on.
    first_param = next(model.parameters(), None)
    target = None if first_param is None else first_param.device

    was_training = model.training
    model.eval()
    correct, examples_seen = 0, 0
    try:
        with torch.no_grad():
            for examples, labels in dataloader:
                if target is not None:
                    examples = examples.to(target)
                    labels = labels.to(target)
                y_pred = torch.argmax(model(examples), dim=-1)
                # Exact integer counts avoid the rounding drift of a
                # floating-point running average.
                correct += (y_pred == labels).sum().item()
                examples_seen += labels.shape[0]
    finally:
        model.train(was_training)

    return 100 * correct / examples_seen if examples_seen else 0.0

In [None]:
def schedule(lr, loss):
    """Scale the learning rate down with the loss once the loss is at most 1.

    Args:
        lr: base learning rate.
        loss: current (smoothed) training loss.

    Returns:
        ``lr`` unchanged while ``loss`` is above 1.0, otherwise ``loss * lr``.
    """
    if loss > 1.0:
        return lr
    return loss * lr

In [None]:
# --- Hyperparameters ---------------------------------------------------------
epochs = 10
batch_size = 128
lr_per_example = 1e-4
eval_every = 1000    # evaluate on the test set every N training steps
adapt_every = 100    # re-apply the learning-rate schedule every N training steps
weight_decay = 0.001

# --- Checkpoint location -----------------------------------------------------
best_model_filename = join("cache", "mnist_classifier.ckpt")
makedirs(dirname(best_model_filename), exist_ok=True)

# The learning rate is specified per example and scaled by the minibatch size.
learning_rate = batch_size * lr_per_example

print(f"learning rate: {learning_rate} (at {batch_size}-minibatches)")

# --- Data --------------------------------------------------------------------
trainset = Dataset(labels=True, train=True)
testset = Dataset(labels=True, train=False)
trainloader = DataLoader(
    trainset, shuffle=True, batch_size=batch_size, num_workers=4)
testloader = DataLoader(
    testset, shuffle=False, batch_size=batch_size, num_workers=4)

# --- Model -------------------------------------------------------------------
clf = Classifier()
if torch.cuda.is_available():
    clf = clf.cuda()
clf.train()
# Remember where the parameters live so batches can be moved to the same device.
device = next(clf.parameters()).device
print(device)

# --- Objective and optimizer -------------------------------------------------
# NLLLoss expects log-probabilities as input, so `Classifier` presumably ends
# in a log-softmax -- TODO confirm against ganlib.classifier.
loss_op = nn.NLLLoss(reduction='mean')
optimizer = optim.Adam(
    clf.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [None]:
# `running_loss` is an exponential moving average of the training loss
# (decay 0.99), started at 1.0.
global_step, running_loss = 0, 1.0
# `evaluate` reports accuracy in percent, so a checkpoint is only written once
# the test accuracy exceeds 2%.
best_acc = 2.0
for epoch in range(epochs):
    # Note: the per-batch size is deliberately not stored in `batch_size`;
    # doing so would overwrite the notebook-level minibatch setting with the
    # size of the last (possibly smaller) batch.
    for examples, labels in trainloader:
        examples = examples.to(device)
        labels = labels.to(device)
        logits = clf(examples)
        loss = loss_op(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss = 0.99 * running_loss + 0.01 * loss.item()

        # Periodically recompute the learning rate from the base rate and the
        # smoothed loss, and push it into every optimizer parameter group.
        if global_step % adapt_every == 0:
            lr = schedule(learning_rate, running_loss)
            print(f"[{global_step}, epoch {epoch+1}] "
                  f"train loss = {running_loss:.3f}, "
                  f"new learning rate = {lr:.5f}")
            for g in optimizer.param_groups:
                g.update(lr=lr)

        # Periodically score the model on the test set and keep the best
        # checkpoint seen so far.
        if global_step % eval_every == 0:
            acc = evaluate(clf, testloader)
            print(f"[{global_step}, epoch {epoch+1}] "
                  f"train loss = {running_loss:.3f}, "
                  f"test acc = {acc:.1f}")

            if acc > best_acc:
                clf.to_checkpoint(best_model_filename)
                best_acc = acc

        global_step += 1

In [None]:
# Score the model as it stands after the last training step: the training loop
# only evaluates every `eval_every` steps, so the final weights may not have
# been evaluated yet.
print("Running final evaluation")
acc = evaluate(clf, testloader)
print(f"[{global_step}, final evaluation] "
      f"train loss = {running_loss:.3f}, "
      f"test acc = {acc:.1f}")

# Overwrite the stored checkpoint only if the final model beats the best
# accuracy recorded during training.
if acc > best_acc:
    clf.to_checkpoint(best_model_filename)
    best_acc = acc

# Train GAN

In [None]:
from os.path import join

import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets

torch.manual_seed(42 * 42)

from ganlib import scripts
from ganlib.gan import GenerativeAdversarialNet
from ganlib.logger import Logger
from ganlib.privacy import compute_renyi_privacy
from ganlib.trainer import DPWGANGPTrainer, WGANGPTrainer
from ganlib.dataset import Dataset
from ganlib.generator import MNISTGenerator, Optimizable

cuda = torch.cuda.is_available()

In [None]:
class MNISTCritic(Optimizable):
    """Convolutional critic that maps each input image to one scalar score.

    Three 5x5 convolutions with stride 2 and padding 2 (64, 128 and 256 output
    channels), each followed by LeakyReLU(0.2), then a flatten and a linear
    projection to a single value.  The projection takes 4 * 4 * 256 features,
    i.e. it expects a 4x4 feature map after the convolutions, which is what
    single-channel 28x28 inputs produce (28 -> 14 -> 7 -> 4).
    """

    def __init__(self):
        super().__init__()
        channels = 64
        conv_opts = dict(kernel_size=5, stride=2, padding=2)
        self.activation = nn.LeakyReLU(negative_slope=0.2)
        self.conv1 = nn.Conv2d(1, channels, **conv_opts)
        self.conv2 = nn.Conv2d(channels, 2 * channels, **conv_opts)
        self.conv3 = nn.Conv2d(2 * channels, 4 * channels, **conv_opts)
        self.flatten = nn.Flatten()
        self.projection = nn.Linear(4 * 4 * 4 * channels, 1)

    def forward(self, images):
        """Return the critic score for every image in the batch."""
        hidden = images
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = self.activation(conv(hidden))
        scores = self.projection(self.flatten(hidden))
        # Drop the trailing singleton dimension left by the projection.
        return scores.squeeze(-1)

In [None]:
def log(logger, info, tag, network, global_step):
    """Report training progress; periodically checkpoint and render outputs.

    Every 25 steps the scalars in ``info`` are recorded via
    ``logger.add_scalars`` and printed.  On every 250th step (counting from 1)
    a checkpoint of ``network`` is stored through ``logger.add_checkpoint`` and
    passed to ``scripts.generate``; ``scripts.inception`` is run as well, but
    only if the classifier checkpoint ``cache/mnist_classifier.ckpt`` exists.
    Afterwards ``network`` is put back into train mode.

    Args:
        logger: object providing ``add_scalars`` and ``add_checkpoint``.
        info: mapping from scalar name to numeric value.
        tag: prefix used for the recorded and printed scalar names.
        network: the model to checkpoint.
        global_step: zero-based index of the current training step.
    """
    # Imported locally because the notebook's import cells only bring in
    # `join` (and `dirname`) from os.path; without this the checkpoint branch
    # below fails with a NameError.
    from os.path import exists

    if global_step % 25 == 0:
        logger.add_scalars(tag, info, global_step)
        s = f"[Step {global_step}] "
        s += ' '.join(f"{tag}/{k} = {v:.3g}" for k, v in info.items())
        print(s)

    if (global_step + 1) % 250 == 0:
        ckpt = logger.add_checkpoint(network, global_step)
        scripts.generate(logger=logger, params=ckpt,
                         step=global_step)
        # The inception step is skipped when no trained classifier checkpoint
        # is available on disk.
        if exists(join("cache", "mnist_classifier.ckpt")):
            scripts.inception(logger=logger, params=ckpt,
                              step=global_step)
        # Make sure training continues in train mode after the scripts ran.
        network.train()

In [None]:
# Set optional parameters
#
# nodp: set to True to train without differential privacy
# sigma: noise multiplier determining epsilon of DP
# grad_clip: per-example L2-gradient clipping constant
#
# Note: the log directory is adapted according to these parameters
# so that you can view the algorithm output for different parameter
# values side-by-side.

nodp = False
sigma = 0.5
grad_clip = 1.0

In [None]:
# Remaining default parameters

batch_size = 128
lr_per_example = 3.125e-6
delta = 1e-5
critic_steps = 4

# Derived values

# The learning rate is specified per example and scaled by the minibatch size.
learning_rate = batch_size * lr_per_example

# One log directory per privacy configuration, all below cache/logs.
logdir = join('cache', 'logs',
              'nodp' if nodp else f"sigma_{sigma}-clip_{grad_clip}")

In [None]:
# Build generator and critic and bundle them in a `GenerativeAdversarialNet`,
# which provides the methods `cuda` and `state_dict` for the pair.

generator = MNISTGenerator()
critic = MNISTCritic()
gan = GenerativeAdversarialNet(generator, critic)
if cuda:
    gan = gan.cuda()

# Training images without labels.

dset = Dataset(labels=False, train=True)
dataloader = DataLoader(
    dset, shuffle=True, batch_size=batch_size, num_workers=4)

# Set up optimization.  The optimizers are made part of the networks, which
# provide the methods `.zero_grad` and `.step` to simplify the code.

generator.init_optimizer(torch.optim.Adam, betas=(0.5, 0.9), lr=learning_rate)
critic.init_optimizer(torch.optim.Adam, betas=(0.5, 0.9), lr=learning_rate)

# Pick the trainer: the differentially private variant unless `nodp` is set.

if not nodp:
    print("training with differential privacy")
    print(f"> delta = {delta}")
    print(f"> sigma = {sigma}")
    print(f"> L2-clip = {grad_clip}")
    trainer = DPWGANGPTrainer(
        sigma=sigma, l2_clip=grad_clip, batch_size=batch_size)
else:
    trainer = WGANGPTrainer(batch_size=batch_size)

print(f"> learning rate = {learning_rate} (at {batch_size}-minibatches)")

In [None]:
# `logs` collects the latest values reported by the trainer steps (plus the
# privacy estimate) and is handed to `log` on every iteration.
logs = {}
global_step = 0
logger = Logger(logdir=logdir)
for epoch in range(100):
    for imgs in dataloader:

        # On every `critic_steps`-th iteration a generator step is run before
        # the critic step; all other iterations run the critic step only.
        if (global_step + 1) % critic_steps == 0:
            genlog = trainer.generator_step(gan)
            logs.update(**genlog)

        critlog = trainer.critic_step(gan, imgs)
        logs.update(**critlog)
        
        # When training with DP, recompute the privacy estimate for the number
        # of steps taken so far and log its `eps` value as 'epsilon'.
        if not nodp:
            spent = compute_renyi_privacy(
                len(dset), batch_size, global_step + 1, sigma, delta)
            logs['epsilon'] = spent.eps

        log(logger, logs, 'train', gan, global_step)

        global_step += 1