<a href="https://colab.research.google.com/github/jusjusjus/noise-in-dpsgd-2020/blob/master/train_dpgan_in_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Fetch the project sources. The `ganlib` package imported by the training
# cell is presumably provided by this repository (confirm if the layout changes).
!git clone https://github.com/jusjusjus/noise-in-dpsgd-2020.git

Cloning into 'noise-in-dpsgd-2020'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (45/45), done.[K
remote: Total 61 (delta 30), reused 47 (delta 16), pack-reused 0[K
Unpacking objects: 100% (61/61), done.


In [6]:
# Work from inside the checkout so that relative paths (e.g. `cache/`) and the
# `ganlib` import resolve against it. The explicit `%cd` magic is used because
# a bare `cd` only works while IPython's automagic is enabled.
%cd noise-in-dpsgd-2020

/content/noise-in-dpsgd-2020


In [3]:
# Report the GPU (if any) attached to this runtime.
!nvidia-smi

Mon Feb 17 13:11:58 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.48.02    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    25W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

In [7]:
from os import makedirs
from os.path import join, dirname

import torch
import numpy as np
from PIL import Image
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets

from ganlib.classifier import Classifier

# Fix torch's global RNG seed (42 squared = 1764) so that weight
# initialisation and minibatch shuffling are repeatable across runs.
torch.manual_seed(42 ** 2)

class Dataset(datasets.MNIST):
    """MNIST dataset that yields normalized ``float32`` arrays.

    The data is downloaded on demand into ``cache/data`` (created if it does
    not exist). All positional and keyword arguments are forwarded to
    ``datasets.MNIST`` after the root directory, e.g. ``Dataset(train=True)``.
    """

    def __init__(self, *args, **kwargs):
        data_dir = join('cache', 'data')
        makedirs(data_dir, exist_ok=True)
        super().__init__(data_dir, *args, download=True, **kwargs)

    def __getitem__(self, i):
        """Return ``(image, label)`` for example ``i``.

        The image is resized to 28x28, converted to a ``float32`` array with
        a new leading axis, divided by 255 and mapped through ``2 * x - 1``.

        NOTE(review): this assumes the parent class returns an object with a
        PIL-style ``resize`` method, i.e. that no tensor-producing
        ``transform`` was passed to the constructor. With 8-bit pixel data the
        result lies in [-1, 1].
        """
        img, labels = super().__getitem__(i)
        # `Image.ANTIALIAS` was removed in Pillow 10. `Image.LANCZOS` is the
        # same resampling filter and is available in old and new releases.
        img = img.resize((28, 28), Image.LANCZOS)
        img = np.array(img)[None, ...]
        img = img.astype(np.float32) / 255.0
        img = 2 * img - 1
        return img, labels


def evaluate(model, dataloader, device=None):
    """Compute the classification accuracy of ``model`` in percent.

    Parameters
    ----------
    model : torch.nn.Module
        Maps a batch of examples to per-class scores; the predicted class is
        the argmax over the last axis.
    dataloader : iterable
        Yields ``(examples, labels)`` batches of tensors.
    device : torch.device, optional
        Device the batches are moved to. Defaults to the device of the
        model's first parameter; if the model has no parameters the batches
        are left where they are.

    Returns
    -------
    float
        Percentage of correctly classified examples, or ``0.0`` if
        ``dataloader`` yields no batches.

    The model is switched to evaluation mode while scoring and is put back
    into whichever mode (train or eval) it was in on entry.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    was_training = model.training
    model.eval()
    correct, examples_seen = 0, 0
    try:
        with torch.no_grad():
            for examples, labels in dataloader:
                if device is not None:
                    examples = examples.to(device)
                    labels = labels.to(device)
                y_pred = torch.argmax(model(examples), dim=-1)
                # Count exact hits; dividing once at the end avoids the
                # rounding drift of a running floating-point mean.
                correct += (y_pred == labels).sum().item()
                examples_seen += labels.shape[0]
    finally:
        # Restore the caller's mode instead of forcing training mode.
        model.train(was_training)

    if examples_seen == 0:
        return 0.0
    return 100 * (correct / examples_seen)


def schedule(lr, loss):
    """Return the learning rate to use for the given loss.

    While ``loss`` is above 1 the base rate ``lr`` is returned unchanged;
    otherwise the rate is scaled down proportionally to the loss
    (``loss * lr``).
    """
    if loss > 1.0:
        return lr
    return loss * lr


# --- Optimisation hyperparameters -------------------------------------------
epochs = 10
batch_size = 128
weight_decay = 0.001
# The step size is specified per example and scaled linearly with the
# minibatch size.
lr_per_example = 1e-4
learning_rate = batch_size * lr_per_example

# --- Bookkeeping intervals, in optimizer steps -------------------------------
adapt_every = 100    # how often the learning rate is re-derived from the loss
eval_every = 1000    # how often the test set is scored

# --- Checkpoint location -----------------------------------------------------
best_model_filename = join("cache", "mnist_classifier.ckpt")
makedirs(dirname(best_model_filename), exist_ok=True)

print(f"learning rate: {learning_rate} (at {batch_size}-minibatches)")

# Data pipelines: only the training split is shuffled.
trainset = Dataset(train=True)
testset = Dataset(train=False)
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

# Model: move it to the GPU when one is available and record where its
# parameters ended up, so batches can be sent to the same device.
clf = Classifier()
if torch.cuda.is_available():
    clf = clf.cuda()
clf.train()
device = next(clf.parameters()).device
print(device)

# NOTE(review): NLLLoss expects log-probabilities, so `Classifier` presumably
# ends in a log-softmax -- confirm against ganlib.classifier.
loss_op = nn.NLLLoss(reduction='mean')
optimizer = optim.Adam(
    clf.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# `running_loss` is an exponential moving average of the minibatch loss
# (decay 0.99), seeded at 1.0.
global_step, running_loss = 0, 1.0
# A checkpoint is only written once test accuracy (in percent) exceeds this.
best_acc = 2.0
for epoch in range(epochs):
    # The per-batch size is deliberately not assigned to `batch_size`: doing
    # so would overwrite the hyperparameter with the size of the last
    # (possibly partial) minibatch.
    for examples, labels in trainloader:
        examples = examples.to(device)
        labels = labels.to(device)
        logits = clf(examples)
        loss = loss_op(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss = 0.99 * running_loss + 0.01 * loss.item()

        # Periodically re-derive the learning rate from the smoothed loss.
        # The schedule always starts from the base `learning_rate`, not from
        # the previously adapted value.
        if global_step % adapt_every == 0:
            lr = schedule(learning_rate, running_loss)
            print(f"[{global_step}, epoch {epoch+1}] "
                  f"train loss = {running_loss:.3f}, "
                  f"new learning rate = {lr:.5f}")
            for g in optimizer.param_groups:
                g["lr"] = lr

        # Periodically score the test set and keep the best model so far.
        if global_step % eval_every == 0:
            acc = evaluate(clf, testloader)
            print(f"[{global_step}, epoch {epoch+1}] "
                  f"train loss = {running_loss:.3f}, "
                  f"test acc = {acc:.1f}")

            if acc > best_acc:
                clf.to_checkpoint(best_model_filename)
                best_acc = acc

        global_step += 1

# The last optimizer steps are usually not aligned with `eval_every`, so score
# the final weights once more and checkpoint them if they are the best seen.
print("Running final evaluation")
acc = evaluate(clf, testloader)
print(f"[{global_step}, final evaluation] "
      f"train loss = {running_loss:.3f}, "
      f"test acc = {acc:.1f}")

if acc > best_acc:
    clf.to_checkpoint(best_model_filename)
    best_acc = acc

learning rate: 0.0128 (at 128-minibatches)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to cache/data/Dataset/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting cache/data/Dataset/raw/train-images-idx3-ubyte.gz to cache/data/Dataset/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to cache/data/Dataset/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting cache/data/Dataset/raw/train-labels-idx1-ubyte.gz to cache/data/Dataset/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to cache/data/Dataset/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting cache/data/Dataset/raw/t10k-images-idx3-ubyte.gz to cache/data/Dataset/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to cache/data/Dataset/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting cache/data/Dataset/raw/t10k-labels-idx1-ubyte.gz to cache/data/Dataset/raw
Processing...
Done!
cuda:0
[0, epoch 1] train loss = 1.015, new learning rate = 0.01280
[0, epoch 1] train loss = 1.015, test acc = 36.2
[100, epoch 1] train loss = 0.566, new learning rate = 0.00724
[200, epoch 1] train loss = 0.289, new learning rate = 0.00370
[300, epoch 1] train loss = 0.164, new learning rate = 0.00210
[400, epoch 1] train loss = 0.105, new learning rate = 0.00135
[500, epoch 2] train loss = 0.077, new learning rate = 0.00099
[600, epoch 2] train loss = 0.060, new learning rate = 0.00077
[700, epoch 2] train loss = 0.054, new learning rate = 0.00069
[800, epoch 2] train loss = 0.049, new learning rate = 0.00062
[900, epoch 2] train loss = 0.046, new learning rate = 0.00059
[1000, epoch 3] train loss = 0.043, new learning rate = 0.00055
[1000, epoch 3] train loss = 0.043, test acc = 99.1
[1100, epoch 3] train loss = 0.040, new learning rate = 0.00051
[1200, epoch 3] train loss = 