In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import opacus
import numpy as np
import random

import matplotlib.pyplot as plt

from copy import deepcopy

In [2]:
from group_amplification.privacy_analysis.composition.pld.accounting import pld_tight_group, pld_traditional_group
from group_amplification.privacy_analysis.base_mechanisms import GaussianMechanism

In [3]:
# Define the model and its training and validation steps.
# The model is a small convolutional network (two conv layers followed by two linear layers).

class ConvNet(nn.Module):
    """Small CNN classifier for single-channel images.

    Layout: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> flatten -> linear(9216->128) -> ReLU -> linear(128->10).

    ``fc1`` takes 9216 = 64 * 12 * 12 features, which pins the spatial input
    size to 28x28 (28 -> 26 -> 24 after the two unpadded 3x3 convolutions,
    12 after pooling).

    ``forward`` returns per-class log-probabilities (log_softmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order determine the state_dict keys and,
        # under a fixed seed, the initial weights -- keep them stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of images to log-probabilities over the 10 classes."""
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = F.max_pool2d(features, 2).flatten(start_dim=1)
        logits = self.fc2(F.relu(self.fc1(features)))
        return F.log_softmax(logits, dim=1)

class LitMNIST(pl.LightningModule):
    """Lightning wrapper that trains/validates ``base_model`` with a given optimizer.

    Args:
        base_model: module mapping an input batch to per-class scores
            (in this notebook, an Opacus ``GradSampleModule`` around ``ConvNet``).
        optimizer: an already-constructed optimizer (here an Opacus
            ``DPOptimizer``); it is returned as-is from ``configure_optimizers``.

    Logs ``train_loss`` per training step, and ``val_loss`` / ``val_accuracy``
    per validation step.
    """

    def __init__(self, base_model, optimizer):
        super().__init__()
        self.base_model = base_model
        self.optimizer = optimizer
        # CrossEntropyLoss applies log_softmax itself. Because log_softmax is
        # idempotent, this still gives the correct NLL when base_model already
        # returns log-probabilities (as ConvNet does), and also works for raw logits.
        self.loss_fn = nn.CrossEntropyLoss()

    def _scores_and_loss(self, batch):
        """Run ``base_model`` on ``(inputs, targets)``; return (scores, loss, targets)."""
        inputs, targets = batch
        scores = self.base_model(inputs)
        return scores, self.loss_fn(scores, targets), targets

    def configure_optimizers(self):
        return self.optimizer

    def training_step(self, batch, batch_idx):
        _, loss, _ = self._scores_and_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        scores, loss, targets = self._scores_and_loss(batch)
        accuracy = scores.argmax(dim=1).eq(targets).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy, prog_bar=True)


In [4]:
# Load MNIST, split off a validation set, and derive the quantities the privacy
# accountant needs (subsampling rate and total number of iterations).

import opacus.data_loader
import opacus.optimizers


num_epochs = 8
num_workers = 4
batch_size_train = 64

# Seed the global RNG so random_split below yields the same split every run.
torch.manual_seed(1)
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
# NOTE(review): hard-coded absolute cluster path; consider making it configurable.
dataset = MNIST('/ceph/hdd/shared/schuchaj_MNIST', train=True, download=True, transform=transform)
train_dataset, val_dataset = random_split(dataset, [55000, 5000])
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=1000, num_workers=num_workers)

noise_multiplier = 0.6
max_grad_norm = 0.0001
# Poisson subsampling rate fed to the privacy accountant. The training cells
# build their loader via DPDataLoader.from_data_loader(train_loader), which
# samples each example with probability 1 / len(train_loader) (= 1/860 here),
# not batch_size_train / len(train_dataset) (= 64/55000). Use the former so
# the accounted rate is the one actually used during training.
subsampling_rate = 1 / len(train_loader)
num_iterations = len(train_loader) * num_epochs

In [None]:
# Privacy profile under the tight, group-specific analysis ("Specific" in the
# plots), evaluated on the epsilon grid 0, 1, ..., 10.
# NOTE(review): presumably indexed as [epsilon position][iteration]; the
# meaning of the positional arguments 0 and 2 is not visible here -- confirm
# against pld_tight_group's signature.
epsilon_grid = np.arange(11)
tight_pld_options = {'value_discretization_interval': 1e-2}
deltas_tight = pld_tight_group(
    epsilon_grid,
    GaussianMechanism(noise_multiplier),
    subsampling_rate,
    0,
    2,
    num_iterations,
    tight_pld_options,
)

In [None]:
# Privacy profile under the traditional (post-hoc) group-privacy analysis
# ("Post-hoc" in the plots), evaluated on the epsilon grid 0, 1, ..., 10.
# NOTE(review): the positional argument 2 presumably is the group size --
# confirm against pld_traditional_group's signature.
epsilon_grid = np.arange(11)
traditional_pld_options = {'value_discretization_interval': 1e-2}
deltas_traditional = pld_traditional_group(
    epsilon_grid,
    GaussianMechanism(noise_multiplier),
    subsampling_rate,
    2,
    num_iterations,
    traditional_pld_options,
)

In [None]:
# Number of training iterations the tight ("Specific") analysis allows: the
# length of the leading run of iterations whose delta at epsilon-grid index 8
# (epsilon = 8 for the grid np.arange(11)) stays strictly below the 1e-5 budget.
# Stopping at the first violation -- instead of counting every entry below the
# budget -- remains a valid stopping point even if the delta sequence is not
# monotone (e.g. from numerical error). Cast to a plain int for pl.Trainer.
# NOTE(review): assumes entry t corresponds to the state after t+1 iterations --
# confirm against pld_tight_group.
within_budget_tight = deltas_tight[8] < 1e-5
violations_tight = np.flatnonzero(~within_budget_tight)
max_iterations_tight = int(violations_tight[0]) if violations_tight.size else int(within_budget_tight.size)
print(max_iterations_tight)

In [None]:
# Number of training iterations the traditional ("Post-hoc") analysis allows:
# the length of the leading run of iterations whose delta at epsilon-grid
# index 8 (epsilon = 8 for the grid np.arange(11)) stays strictly below the
# 1e-5 budget. Stopping at the first violation -- instead of counting every
# entry below the budget -- remains a valid stopping point even if the delta
# sequence is not monotone. Cast to a plain int for pl.Trainer.
# NOTE(review): assumes entry t corresponds to the state after t+1 iterations --
# confirm against pld_traditional_group.
within_budget_traditional = deltas_traditional[8] < 1e-5
violations_traditional = np.flatnonzero(~within_budget_traditional)
max_iterations_traditional = (int(violations_traditional[0]) if violations_traditional.size
                              else int(within_budget_traditional.size))
print(max_iterations_traditional)

In [None]:
# Train ConvNet with DP-Adam (Opacus) for the number of steps permitted by the
# tight ("Specific") group-privacy accounting.
seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

base_model = ConvNet()
optimizer = torch.optim.Adam(base_model.parameters(), lr=1e-3)

# Per-sample gradients, clipping to max_grad_norm and Gaussian noise via Opacus.
base_model = opacus.grad_sample.GradSampleModule(base_model)
optimizer = opacus.optimizers.DPOptimizer(optimizer,
                                          expected_batch_size=batch_size_train,
                                          noise_multiplier=noise_multiplier,
                                          max_grad_norm=max_grad_norm)
# Build the Poisson-sampling loader under a new name instead of overwriting
# `train_loader`; otherwise re-running this cell (or running another training
# cell afterwards) wraps an already-wrapped DPDataLoader again.
dp_train_loader = opacus.data_loader.DPDataLoader.from_data_loader(train_loader, distributed=False)


model = LitMNIST(base_model, optimizer)

# The budget cell may yield a NumPy integer; pass a plain int as max_steps.
trainer = pl.Trainer(max_epochs=num_epochs, max_steps=int(max_iterations_tight))
trainer.validate(model, val_loader)
trainer.fit(model, dp_train_loader, val_loader)
trainer.validate(model, val_loader)

In [None]:
# Train ConvNet with DP-Adam (Opacus) for the number of steps permitted by the
# traditional ("Post-hoc") group-privacy accounting, with the same seed.
seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

base_model = ConvNet()
optimizer = torch.optim.Adam(base_model.parameters(), lr=1e-3)

# Per-sample gradients, clipping to max_grad_norm and Gaussian noise via Opacus.
base_model = opacus.grad_sample.GradSampleModule(base_model)
optimizer = opacus.optimizers.DPOptimizer(optimizer,
                                          expected_batch_size=batch_size_train,
                                          noise_multiplier=noise_multiplier,
                                          max_grad_norm=max_grad_norm)
# Build the Poisson-sampling loader under a new name instead of overwriting
# `train_loader`; otherwise re-running this cell wraps an already-wrapped
# DPDataLoader again.
dp_train_loader = opacus.data_loader.DPDataLoader.from_data_loader(train_loader, distributed=False)

model = LitMNIST(base_model, optimizer)

# The budget cell may yield a NumPy integer; pass a plain int as max_steps.
trainer = pl.Trainer(max_epochs=num_epochs, max_steps=int(max_iterations_traditional))
trainer.validate(model, val_loader)
trainer.fit(model, dp_train_loader, val_loader)
trainer.validate(model, val_loader)

In [1]:
import seaborn as sns

In [2]:
# Presumably the output directory for exported plots.
# NOTE(review): not referenced by any cell shown here -- the figures are never
# saved (e.g. via fig.savefig); either save them or drop this constant.
save_dir = '/ceph/hdd/staff/schuchaj/group_amplification_plots/neurips24/mnist/half_page'

### Plot the validation accuracy over time

In [3]:
# Validation accuracies copied by hand from metrics.csv in lightning_logs
# version N-1 and N (N = most recent run). Presumably N-1 is the tight run and
# N the traditional run (they are trained in that order) -- confirm.
# NOTE(review): entry 0 is presumably the validation before any training
# (trainer.validate is called before fit); the remaining entries are per epoch.
acc_before_training = 0.09040000289678574

accs_tight = [
    acc_before_training,
    0.8776000142097473,
    0.8948000073432922,
    0.9020000100135803,
    0.9047999978065491,
    0.9010000824928284,
    0.9067999720573425,
    0.9104000926017761,
    0.9121999740600586,
]

# NOTE(review): the post-hoc run presumably hit its max_steps budget early, so
# its final accuracy is repeated to span the same 8 epochs -- confirm.
accs_traditional = [acc_before_training] + [0.7963999509811401] * 8

In [4]:
# Validation accuracy per epoch for both accounting methods.
sns.set_theme()

fig, ax = plt.subplots()
colors = sns.color_palette('colorblind', 2)

# (values, colour, legend label, extra line style); the post-hoc baseline is dashed.
curves = [
    (accs_traditional, colors[0], 'Post-hoc', {'linestyle': 'dashed'}),
    (accs_tight, colors[1], 'Specific', {}),
]
for values, color, label, style in curves:
    ax.plot(values, marker='x', c=color, label=label, **style)

ax.set_ylabel('Val. accuracy', fontsize=9)
ax.set_xlabel('Epoch', fontsize=9)
# Accuracy is a fraction; use the full [0, 1] range.
ax.set_ylim(0, 1)

ax.legend(loc='lower right')

### Plot privacy over time

In [None]:
# delta(epsilon=8) after each training iteration for the post-hoc and specific
# analyses, compared with the 1e-5 delta budget used to choose the number of
# training steps.
eps_index = 8  # position in the epsilon grid np.arange(11), i.e. epsilon = 8
delta_budget = 1e-5

sns.set_theme()

fig, ax = plt.subplots()

pal = sns.color_palette('colorblind', 2)

ax.plot(deltas_traditional[eps_index], c=pal[0], label='Post-hoc', linestyle='dashed')
ax.plot(deltas_tight[eps_index], c=pal[1], label='Specific')
ax.plot(np.ones_like(deltas_tight[eps_index]) * delta_budget,
        color='black',
        linestyle='dotted',
        label='Budget')

# Raw string: '\d' is an invalid escape in a normal string literal (warning,
# SyntaxWarning since Python 3.12). The rendered text is unchanged.
ax.set_ylabel(r'ADP $\delta(\varepsilon=8)$', fontsize=9)
ax.set_xlabel('Iteration t', fontsize=9)

ax.set_yscale('log')

ax.legend(loc='lower right')