In [242]:
!pip install pytorch-lightning loguru torchmetrics opacus

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [243]:
%load_ext autoreload
%autoreload 2

import warnings

warnings.filterwarnings("ignore")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [244]:
from torchvision.datasets import MNIST
import torchvision.transforms as tfs
from torch.utils.data import DataLoader

# MNIST served in batches of 8. Each image is converted to a tensor and then
# normalised with mean 0.1307 and std 0.3081. The dataset is downloaded into
# the '.data' directory when it is not already there. Arguments that are not
# given here are left at the library defaults.
dataloader = DataLoader(
    MNIST(
        root='.data',
        transform=tfs.Compose([tfs.ToTensor(), tfs.Normalize((0.1307,), (0.3081,))]),
        download=True,
    ),
    batch_size=8,
)

In [245]:
from loguru import logger
import os
import warnings

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.optimizer import Optimizer
import torchmetrics
from opacus import PrivacyEngine, GradSampleModule
from opacus.accountants import RDPAccountant
from opacus.data_loader import DPDataLoader
from opacus.optimizers import DPOptimizer
from opacus.lightning import DPLightningDataModule
from pytorch_lightning.utilities.cli import LightningCLI


class MyGradSampleModule(GradSampleModule):
    """``GradSampleModule`` that writes a loguru debug line before each delegated call.

    Every override passes its arguments through unchanged to the parent
    implementation and returns the parent's result; the only addition is the
    trace message.
    """

    def forward(self, *args, **kwargs):
        """Trace the call, then run the parent ``forward``."""
        logger.debug("...")
        return super().forward(*args, **kwargs)

    def add_hooks(self, *args, **kwargs):
        """Trace the call, then run the parent ``add_hooks``."""
        logger.debug("Adding hooks...")
        return super().add_hooks(*args, **kwargs)

    def remove_hooks(self, *args, **kwargs):
        """Trace the call, then run the parent ``remove_hooks``."""
        logger.debug("Removing hooks...")
        return super().remove_hooks(*args, **kwargs)

    def disable_hooks(self, *args, **kwargs):
        """Trace the call, then run the parent ``disable_hooks``."""
        logger.debug("Disabling hooks...")
        return super().disable_hooks(*args, **kwargs)

    def _close(self, *args, **kwargs):
        """Trace the call, then run the parent ``_close``."""
        logger.debug("Closing...")
        return super()._close(*args, **kwargs)

    def capture_activations_hook(self, *args, **kwargs):
        """Trace the call, then run the parent ``capture_activations_hook``."""
        logger.debug("activation hook")
        return super().capture_activations_hook(*args, **kwargs)

    def capture_backprops_hook(self, *args, **kwargs):
        """Trace the call, then run the parent ``capture_backprops_hook``."""
        logger.debug("backprops hook")
        return super().capture_backprops_hook(*args, **kwargs)

    def rearrange_grad_samples(self, *args, **kwargs):
        """Log whether the ``module`` keyword argument has activations, then delegate.

        Only the keyword argument named ``module`` is inspected. When it is
        absent the lookup yields ``None``, which has no ``activations``
        attribute, so the "no activations" message is logged for ``NoneType``.
        """
        module = kwargs.get('module')
        module_name = module.__class__.__name__
        module_type = type(module)
        if hasattr(module, 'activations'):
            msg = f"activations for module {module_name} of type {module_type}: {type(module.activations)}"
        else:
            msg = f"no activations in module {module_name} of type {module_type}!"
        logger.debug(msg)
        return super().rearrange_grad_samples(*args, **kwargs)


class ManualModel(pl.LightningModule):
    """MNIST conv-net trained with manual optimization and two DP optimizers.

    The network's parameters are split across two SGD optimizers. When
    ``enable_dp`` is true, the network is wrapped in ``MyGradSampleModule`` and
    each optimizer is wrapped in an Opacus ``DPOptimizer`` whose step hook
    reports to a single shared ``RDPAccountant``.

    Most Lightning hooks are overridden only to emit a loguru debug line before
    delegating to the parent implementation, so that the order in which the
    hooks fire can be read from the log.
    """

    def __init__(
        self,
        lr: float = 0.1,
        enable_dp: bool = True,
        delta: float = 1e-5,
        noise_multiplier: float = 1.0,
        max_grad_norm: float = 1.0,
    ):
        """A simple conv-net for classifying MNIST with differential privacy training
        Args:
            lr: Learning rate
            enable_dp: Enables training with privacy guarantees using Opacus (if True), vanilla SGD otherwise
            delta: Target delta for which (eps, delta)-DP is computed
            noise_multiplier: Noise multiplier
            max_grad_norm: Clip per-sample gradients to this norm
        """
        super().__init__()
        # Hyper-parameters
        self.lr = lr
        # Parameters
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, 8, 2, padding=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 1),
            torch.nn.Conv2d(16, 32, 4, 2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 1),
            torch.nn.Flatten(),
            torch.nn.Linear(32 * 4 * 4, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 10),
        )
        # Metrics
        self.test_accuracy = torchmetrics.Accuracy()
        # Differential privacy
        self.accountant = RDPAccountant()
        self.enable_dp = enable_dp
        self.delta = delta
        self.noise_multiplier = noise_multiplier
        self.max_grad_norm = max_grad_norm
        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

    def configure_optimizers(self):
        """Build the two SGD optimizers and, if ``enable_dp``, make them private.

        Returns:
            A list with two optimizers: the first over the first three
            parameter tensors of the network, the second over the rest. With
            ``enable_dp`` they are ``DPOptimizer`` wrappers, otherwise the
            plain SGD optimizers.
        """
        print("Configuring optimizers...")
        parameters = list(self.model.parameters())
        # NOTE(review): the learning rate is fixed at 0.05 here; ``self.lr`` is
        # stored by ``__init__`` but not read anywhere in this class. Kept as is
        # so default training behaviour does not change.
        optimizers = [
            torch.optim.SGD(parameters[:3], lr=0.05),
            torch.optim.SGD(parameters[3:], lr=0.05),
        ]
        if not self.enable_dp:
            # Vanilla SGD: no per-sample gradients, clipping, noise or accounting.
            return optimizers
        # privacy
        # Wrap only once so that repeated calls do not nest wrappers.
        if not isinstance(self.model, GradSampleModule):
            self.model = MyGradSampleModule(self.model)
        # Reaches into private Trainer attributes to obtain the train dataloader.
        data_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
        sample_rate: float = 1 / len(data_loader)
        dataset_size: int = len(data_loader.dataset)  # type: ignore
        expected_batch_size = int(dataset_size * sample_rate)
        private_optimizers = []
        for base_optimizer in optimizers:
            private_optimizer = DPOptimizer(
                optimizer=base_optimizer,
                noise_multiplier=self.noise_multiplier,
                # Use the configured clipping norm instead of a hard-coded 1.0.
                max_grad_norm=self.max_grad_norm,
                expected_batch_size=expected_batch_size,
            )
            # Every step of either optimizer is reported to the same accountant.
            private_optimizer.attach_step_hook(
                self.accountant.get_optimizer_hook_fn(sample_rate=sample_rate)
            )
            private_optimizers.append(private_optimizer)
        # return
        return private_optimizers

    def training_step(self, batch, batch_idx):
        """Manually step every optimizer on ``batch``.

        A separate forward and backward pass is run for each optimizer. The
        loss that is logged is the one computed for the last optimizer.
        """
        optimizers = self.optimizers()
        if not isinstance(optimizers, (tuple, list)):
            optimizers = [optimizers]
        for optimizer_idx, optimizer in enumerate(optimizers):
            assert isinstance(optimizer, Optimizer)
            optimizer.zero_grad()
            logger.debug(f"Optimizer idx: {optimizer_idx}")
            loss = self.loss(batch)
            self.manual_backward(loss)
            optimizer.step()
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)

    def on_train_epoch_end(self):
        """Log and print the privacy budget spent so far (only when ``enable_dp``)."""
        if not self.enable_dp:
            # No DP optimizer reported to the accountant, so there is nothing to log.
            return
        # Logging privacy spent: (epsilon, delta)
        epsilon, best_alpha = self.accountant.get_privacy_spent(delta=self.delta)
        self.log("epsilon", epsilon, on_epoch=True, prog_bar=True)
        print(f"\nepsilon = {epsilon}; best_alpha = {best_alpha}")

    def loss(self, batch):
        """Return the cross-entropy between the network output and the targets.

        Args:
            batch: A ``(data, target)`` pair.
        """
        logger.debug("Forward...")
        data, target = batch
        output = self(data)
        return F.cross_entropy(output, target)

    # ------------------
    # Optim
    # ------------------
    # The hooks below only add a debug line and then defer to the parent class.

    def on_train_batch_start(self, *args, **kwargs):
        """Trace the hook, then delegate."""
        logger.debug("Consuming batch...")
        return super().on_train_batch_start(*args, **kwargs)

    def on_before_zero_grad(self, optimizer):
        """Trace which optimizer class is about to be zeroed, then delegate."""
        logger.debug(f"Using optimizer: {optimizer.__class__.__name__}")
        return super().on_before_zero_grad(optimizer)

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer: Optimizer, optimizer_idx: int = None):
        """Trace the optimizer index and class, then delegate."""
        logger.debug(f"[optimizer_idx:{optimizer_idx}] optimizer: {optimizer.__class__.__name__}")
        return super().optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx)

    def on_before_backward(self, *args, **kwargs):
        """Trace the hook, then delegate."""
        logger.debug("...")
        return super().on_before_backward(*args, **kwargs)

    def backward(self, loss: torch.Tensor, optimizer, optimizer_idx: int = None, *args, **kwargs):
        """Trace the backward call, then delegate.

        Extra positional and keyword arguments are forwarded to the parent so
        that options supplied by the caller are not silently dropped.
        """
        logger.debug(f"[optimizer_idx:{optimizer_idx}] optimizer: {optimizer.__class__.__name__} - loss: {loss}")
        return super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)

    def on_after_backward(self, *args, **kwargs):
        """Trace the hook, then delegate."""
        logger.debug("...")
        return super().on_after_backward(*args, **kwargs)

    def on_before_optimizer_step(self, optimizer, optimizer_idx: int = None):
        """Trace the optimizer index, then delegate."""
        logger.debug(f"[optimizer_idx:{optimizer_idx}]")
        return super().on_before_optimizer_step(optimizer, optimizer_idx)

    def configure_gradient_clipping(self, *args, **kwargs):
        """Trace the hook, then delegate."""
        logger.debug("...")
        return super().configure_gradient_clipping(*args, **kwargs)

    def optimizer_step(self, epoch=None, batch_idx=None, optimizer: Optimizer = None, optimizer_idx: int = None, optimizer_closure=None, on_tpu=None, using_native_amp=None, using_lbfgs=None):
        """Trace the optimizer index and closure, then delegate."""
        logger.debug(f"[optimizer_idx:{optimizer_idx}] closure: {optimizer_closure}")
        return super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)

    def on_train_batch_end(self, *args, **kwargs):
        """Trace the hook, then delegate."""
        logger.debug("...")
        return super().on_train_batch_end(*args, **kwargs)

In [246]:
class AutomaticModel(ManualModel):
    """``ManualModel`` variant that leaves the optimization loop to Lightning."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``ManualModel`` and re-enable automatic optimization."""
        super().__init__(*args, **kwargs)
        # The parent constructor sets this flag to False; switch it back on.
        self.automatic_optimization = True

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Compute the loss for ``batch``, log it per epoch, and return it."""
        logger.debug(f"Optimizer idx: {optimizer_idx}")
        batch_loss = self.loss(batch)
        self.log("train/loss", batch_loss, prog_bar=True, on_epoch=True, on_step=False)
        return batch_loss

In [247]:
# Run a single training step with the manual-optimization model. The logger,
# checkpointing and the model summary are turned off.
trainer = pl.Trainer(
    max_steps=1,
    logger=False,
    enable_checkpointing=False,
    enable_model_summary=False,
)
trainer.fit(model=ManualModel(), train_dataloaders=dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
2022-10-18 16:06:28.776 | DEBUG    | __main__:add_hooks:26 - Adding hooks...


Configuring optimizers...


Training: 0it [00:00, ?it/s]

2022-10-18 16:06:28.864 | DEBUG    | __main__:on_train_batch_start:166 - Consuming batch...
2022-10-18 16:06:28.868 | DEBUG    | __main__:training_step:143 - Optimizer idx: 0
2022-10-18 16:06:28.871 | DEBUG    | __main__:loss:156 - Forward...
2022-10-18 16:06:28.874 | DEBUG    | __main__:forward:22 - ...
2022-10-18 16:06:28.878 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:28.882 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:28.886 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:28.888 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:28.893 | DEBUG    | __main__:on_before_backward:178 - ...
2022-10-18 16:06:28.894 | DEBUG    | __main__:backward:182 - [optimizer_idx:None] optimizer: NoneType - loss: 2.313441038131714
2022-10-18 16:06:28.896 | DEBUG    | __main__:capture_backprops_hook:46 - backprops hook
2022-10-18 16:06:28.898 | DEBUG   


epsilon = 0.4501099109865881; best_alpha = 18.0


In [248]:
# Same single-step run, this time with the automatic-optimization model.
# NOTE(review): the output recorded for this cell ends in a ValueError, so the
# cell is expected to fail as written.
trainer = pl.Trainer(
    max_steps=1,
    logger=False,
    enable_checkpointing=False,
    enable_model_summary=False,
)
trainer.fit(model=AutomaticModel(), train_dataloaders=dataloader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
2022-10-18 16:06:29.579 | DEBUG    | __main__:add_hooks:26 - Adding hooks...


Configuring optimizers...


Training: 0it [00:00, ?it/s]

2022-10-18 16:06:29.700 | DEBUG    | __main__:on_train_batch_start:166 - Consuming batch...
2022-10-18 16:06:29.704 | DEBUG    | __main__:optimizer_step:198 - [optimizer_idx:0] closure: <pytorch_lightning.loops.optimization.optimizer_loop.Closure object at 0x7fc26274cc50>
2022-10-18 16:06:29.708 | DEBUG    | __main__:training_step:7 - Optimizer idx: 0
2022-10-18 16:06:29.713 | DEBUG    | __main__:loss:156 - Forward...
2022-10-18 16:06:29.716 | DEBUG    | __main__:forward:22 - ...
2022-10-18 16:06:29.719 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:29.724 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:29.726 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:29.728 | DEBUG    | __main__:capture_activations_hook:42 - activation hook
2022-10-18 16:06:29.733 | DEBUG    | pytorch_lightning.loops.optimization.optimizer_loop:closure:139 - Calling <function OptimizerLoop._make_zero_g

ValueError: ignored

In [None]:
from pytorch_lightning.loops.optimization.optimizer_loop import Closure