In [1]:
import imp
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import torchvision
from torch.utils.data import DataLoader, Dataset
from opacus import PrivacyEngine
import torch.nn.functional as F
from operations import *
from genotypes import PRIMITIVES
from genotypes import Genotype
from opacus.grad_sample import GradSampleModule, register_grad_sampler
from opacus.utils.batch_memory_manager import BatchMemoryManager
from typing import Dict
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from copy import deepcopy

  warn(f"Failed to load image Python extension: {e}")


# Finding sample-wise Gradients for Supernet

In this notebook we derive how we come up with sample-wise gradients for our supernet we use to perform differentiable NAS. For this we note that the supernet is based on a convex, weighted combination of operations which are all applied to the same input, that is each _mixed operation_ is defined as follows:
\begin{equation}
    m = \sum_{o \in O} \alpha_o \cdot o(x)
\end{equation}
Each $o$ is a convolution/pooling operation or a regular neural network, thus opacus already knows how to compute sample-based gradients for all parameters of each $o$. Thus, with a smart design of our supernet-architecture we can avoid the computation of sample-wise gradients for all the operations we have in use and pass the heavy lifting to opacus. The problem then reduces to providing opacus with sample-wise gradients w.r.t $\alpha_o$ for each $o$.

For this, let's see how we compute the gradients of an arbitrary loss w.r.t. the alpha-parameters in a simple setup: We only have 3 operations, each associated with a certaing weight $\alpha_o$. The mixed operation is then followed by a linear transformation producing the output, thus the network reads:
\begin{equation}
    \hat{y} = \bigg(\sum_{o \in O} \alpha_o \cdot o(x) \bigg) \cdot \mathbf{W}
\end{equation}

The following shwos the forward pass of such a network:

In [194]:
X = [torch.randn(1, 5) for _ in range(0, 4)]
alphas = nn.Parameter(torch.ones(4) / 4, requires_grad=True)
W = nn.Parameter(torch.randn(5, 1), requires_grad=True)

In [195]:
softmaxed_alphas = torch.softmax(alphas, dim=0)
softmaxed_alphas

tensor([0.2500, 0.2500, 0.2500, 0.2500], grad_fn=<SoftmaxBackward0>)

In [196]:
mop = sum(alphas[i] * X[i] for i in range(0, 4))
y = torch.matmul(mop, W)
y

tensor([[1.1655]], grad_fn=<MmBackward0>)

Let's proceed with computing a loss and updating the parameters. To keep things easy, let's pretend our loss is just the sum of all elements in our output. We then can compute the gradients w.r.t. each $\alpha_o$ easily by calling the backward-method:

In [197]:
l = torch.sum(y)
l.backward()
alphas.grad

tensor([-1.5030,  3.3426, -0.0513,  2.8738])

The gradient can be expressed as follows:
\begin{equation}
    \frac{\partial \ell}{\partial \alpha_o} = \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \frac{\partial z^{(0)}_k}{\partial \alpha_o}
\end{equation}
Here we sum over all gradients of our $n$ outputs w.r.t. $\alpha_o$. $z^{(0)}_k$ denotes the $k$-th element of the last layer (output). Expanding this further yields:
\begin{align}
    \frac{\partial \ell}{\partial \alpha_o} = & \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \sum_{d=1}^{|L_1|} \mathbf{W}^{(1)}_{d k} \cdot \frac{\partial z_d^{(1)}}{\partial \alpha_o} \\
    & \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \sum_{d=1}^{|L_1|} \mathbf{W}^{(1)}_{d k} \cdot \frac{\partial \sum_{o \in O} \big( \alpha_o o(z^{(2)}) \big)}{\partial \alpha_o} \\
    &  \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \sum_{d=1}^{|L_1|} \mathbf{W}^{(1)}_{d k} \cdot o(z^{(2)})_d
\end{align}
Since $\frac{\partial \ell}{\partial z^{(0)}_k} = 1$ we obtain:
\begin{align}
    \frac{\partial \ell}{\partial \alpha_o} & = \sum_{k=1}^n \sum_{d=1}^{|L_1|} \mathbf{W}_{d k}^{(1)} \cdot o(z^{(2)})_d \\
\end{align}
This reduces to:
\begin{equation}
    \frac{\partial \ell}{\partial \alpha_o} = \sum_{k=1}^n \big(\mathbf{W}^{(1)^T}\big)_k \cdot o(z^{(2)})
\end{equation}

In [198]:
X = torch.cat(X, dim=0)
1 * X.matmul(W) # 1 = derivative w.r.t. the output of mixed operation, rest as derived above

tensor([[-1.5030],
        [ 3.3426],
        [-0.0513],
        [ 2.8738]], grad_fn=<MulBackward0>)

## Extending the Minimal Example
Now we will start extending the above approach. As you can see computing the gradients and deriving the forumlas can get cumbersome really quickly. That's why we should use opacus' capabilities of computing sample-wise gradients for as many modules as possible. Below we define 3 operations that will be used withing our NAS-approach. For these modules opacus already knows how to compute sample-wise gradients, thus we can just go ahead and use them as they are.

In [2]:
class Op1(nn.Module):

    def __init__(self, in_dim, h_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, out_dim)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return x

class Op2(nn.Module):

    def __init__(self, in_dim, h_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, out_dim)
        
    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

class Op3(nn.Module):

    def __init__(self, in_dim, h_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, out_dim)
        
    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return x

PRIMS = [Op1, Op2, Op3]

The tricky part is the mixed operation-module. If we would use plain pytorch, we just could go ahead and build one `MixedOp` module taking care of calling each operation and compute the convex combination of the operation's output. However, since opacus does not know the MixedOp-module it cannot compute the sample-wise gradients w.r.t. the alpha-parameters. Thus we have to tell opacus how to compute the gradients. For this opacus provides the activations (input to our module) and the gradients w.r.t. the outputs of our module. If we would go for a "plain-pytorch approach" this would require us to compute the gradients w.r.t. alpha-parameters and the model-parameters of each operation "by hand". Since this is cumbersome and not error-prone, we split up the MixedOp into two parts: One `ParallelOp` which does nothing but applying each operation on the same input data and a `MixedOp` (please don't get confused by the naming) which just cares about computing the convex combination of the oerpation's output computed by the ParallelOp. 

This way we can compute the gradients w.r.t. the alphas easily by just applying the same reasoning as above. The gradients can then be computed by computing the vector-product of the activations and the gradients w.r.t. the MixedOp (which are both provided by opacus). Mathematically this can be expressed as follows assuming we have an $n \times i$-dimensional real activation matrix $\mathbf{a}$ and an $i$-dimensional real vector $\nabla \mathbf{m}$ representing the gradients w.r.t. our mixed operation. We have an activation of $n \times i$ because we compute the convex combination of $n$ operations, each producing outputs of dimension $i$. Since each of the $n$ operations is associated with a weight $\alpha_j$, we aim to compute the gradient w.r.t. each of the $n$ weights, thus we aim to obtain a $n$-dimensional gradient vector for one sample and a $B \times n$-dimensional gradient matrix for a batch of size $B$. We can easily compute this using an einsum:
\begin{equation}
    \nabla_{j} \alpha_m = \sum_{k=1}^i \nabla \mathbf{m}_i \cdot \mathbf{a}_{m k}
\end{equation}
Here $\nabla_j \alpha_m$ is the $j$-th element of the gradient vector w.r.t. to the alphas associated with operation $m$. As we can see this is just a more general version of the equation we've derived above.

But this is not all we have to do: Remember we have the ParallelOp computing all the operation's outputs. This module does not have any parameters, thus there is nothing we can update and so there are also no gradients for this module. Thus we can tell opacus that there is nothing to compute in the backward pass. Details on the implementation can be obtained below:

In [3]:
class ParallelOp(nn.Module):
    
    def __init__(self, in_dim, h_dim, out_dim) -> None:
        super(ParallelOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMS:
            self._ops.append(primitive(in_dim, h_dim, out_dim))

    def forward(self, x):
        operation_outs = []
        for op in self._ops:
            out = op(x)
            operation_outs.append(out)
        return torch.stack(operation_outs)

In [4]:
class MixedOp(nn.Module):

    def __init__(self):
        super(MixedOp, self).__init__()
        self.alphas = nn.Parameter(torch.zeros(len(PRIMS)), requires_grad=True)

    def forward(self, x):
        weights = torch.softmax(self.alphas, 0)
        return sum(w * op_out for w, op_out in zip(weights, x))

    def arch_params(self):
        return [self.alphas]

In [5]:
class LastLayer(nn.Module):

    def __init__(self, in_dim, hdim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hdim)
        self.fc2 = nn.Linear(hdim, out_dim)
        
    def forward(self, x):
        x = torch.hstack(x)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [6]:
mixed_op = MixedOp()
net = nn.Sequential(ParallelOp(28*28, 512, 256), mixed_op, LastLayer())
optim = torch.optim.SGD(mixed_op.arch_params(), 0.01)
loss = torch.nn.CrossEntropyLoss()
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
train_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=True, transform=transform)
val_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=False, transform=transform)
train_loader = DataLoader(train_data, 64)
val_loader = DataLoader(val_data, 64)

@register_grad_sampler(ParallelOp)
def grad_sampler_parallel_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    return {}

@register_grad_sampler(MixedOp)
def grad_sampler_mixed_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    grad = torch.einsum('nbi,bi->nb', activations, backprops)
    ret = {
        layer.alphas: grad
    }
    return ret

pe = PrivacyEngine()
netc = deepcopy(net)
net_, optim_, train_loader_ = pe.make_private(module=net, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
x_first, y_first = next(iter(train_loader_))
x_first = x_first.reshape((x_first.shape[0], 28*28))

TypeError: __init__() missing 3 required positional arguments: 'in_dim', 'hdim', and 'out_dim'

In [None]:
x_first = torch.randn((64, 784))
y_first = torch.ones(64, dtype=torch.long)

In [12]:
for pnetc, pnet_ in zip(netc.parameters(), net_.parameters()):
    print(torch.all(pnet_.data == pnetc.data))

NameError: name 'netc' is not defined

In [46]:
y_pred = net_(x_first)
l = loss(y_pred, y_first)
l.backward()



In [47]:
y_pred = netc(x_first)
l = loss(y_pred, y_first)
l.backward()

In [48]:
for pnetc, pnet_ in zip(netc.parameters(), net_.parameters()):
    print(torch.all(pnetc.grad.data == pnet_.grad.data))

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)


> NOTE: Does this make sense? 
> 
> Yes it does since the optimizer performs the clipping and adds noise, thus the gradients are the same

## Building up a Cell
Now that we know how to build MixedOps with minimal effort, we can use the MixedOps in order to build up cells and we will use the cells in turn to build up our architecture search space. Since cells don't introduce any additional parameters, building a cell based on a set of MixedOps shoud be straightforward.

In [6]:
class Cell(nn.Module):
    # TODO: Do the same as above and check that gradients are correct!
    def __init__(self) -> None:
        super().__init__()
        self.nodes = nn.ModuleList()
        # initialize stem modules with some conv-operation self.stem0, self.stem1 = 
        curr_in_dim = 28*28
        curr_hdim = int(0.75 * curr_in_dim)
        curr_out_dim = int(0.75 * curr_hdim)
        dims = [curr_out_dim]
        for i in range(5):
            if i == 0:
                mop = nn.Sequential(ParallelOp(curr_in_dim, curr_hdim, curr_out_dim), MixedOp())
            else:
                in_dim = sum(dims)
                hdim = int(0.75*in_dim)
                out_dim = int(0.75*hdim)
                dims.append(out_dim)
                mop = nn.Sequential(ParallelOp(in_dim, hdim, out_dim), MixedOp())
            self.nodes.append(mop)

        
    def forward(self, x):
        inp = [self.nodes[0](x)]
        for op in self.nodes[1:]:
            print(torch.hstack(inp).shape)
            out = op(torch.hstack(inp))
            inp.append(out)
        
        return inp

## Building a Supernet
Now that we know how to build up a cell we can proceed and glue several cells to one network.

In [7]:
class Network(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.cell1 = Cell()
        self.linear1 = nn.Linear(2623, 28*28)
        self.cell2 = Cell()
        self.out = LastLayer(2623, 256, 10)

    def forward(self, x):
        x = self.cell1(x)
        x = torch.hstack(x)
        x = self.linear1(x)
        x = self.cell2(x)
        return self.out(x)

In [8]:
def get_params(net: nn.Module, param_type='arch'):
    parameters = []
    for name, param in net.named_parameters():
        if param_type == 'arch':
            if 'alphas' in name:
                parameters.append(param)
        elif param_type == 'model':
            if 'alphas' not in name:
                parameters.append(param)
        else:
            raise ValueError('Unsupported parameter type, must be either arch or model')
    return parameters

In [9]:
#net = nn.Sequential(Cell(), LastLayer(2623, 256, 10))
net = Network()
params = get_params(net, 'arch')
optim = torch.optim.SGD(params, 0.01)
loss = torch.nn.CrossEntropyLoss()
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
train_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=True, transform=transform)
val_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=False, transform=transform)
train_loader = DataLoader(train_data, 64)
val_loader = DataLoader(val_data, 64)

@register_grad_sampler(ParallelOp)
def grad_sampler_parallel_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    return {}

@register_grad_sampler(MixedOp)
def grad_sampler_mixed_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    grad = torch.einsum('nbi,bi->nb', activations, backprops)
    ret = {
        layer.alphas: grad
    }
    return ret

pe = PrivacyEngine()
netc = deepcopy(net)
net_, optim_, train_loader_ = pe.make_private(module=net, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
x_first, y_first = next(iter(train_loader_))
x_first = x_first.reshape((x_first.shape[0], 28*28))



In [15]:
x_first = torch.randn((64, 784))
y_first = torch.ones(64, dtype=torch.long)

In [18]:
y_pred = net_(x_first)
l = loss(y_pred, y_first)
l.backward()

y_pred = netc(x_first)
l = loss(y_pred, y_first)
l.backward()



torch.Size([64, 441])
torch.Size([64, 688])
torch.Size([64, 1075])
torch.Size([64, 1679])
torch.Size([64, 441])
torch.Size([64, 688])
torch.Size([64, 1075])
torch.Size([64, 1679])
torch.Size([64, 441])
torch.Size([64, 688])
torch.Size([64, 1075])
torch.Size([64, 1679])
torch.Size([64, 441])
torch.Size([64, 688])
torch.Size([64, 1075])
torch.Size([64, 1679])


In [19]:
for pnetc, pnet_ in zip(netc.parameters(), net_.parameters()):
    print(torch.all(pnetc.grad.data == pnet_.grad.data))

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)

# Transfer to CNNs
In the above example we have used standard MLPs as our operations. However, our ultimate goal is to perform image classification, thus our operation-space consists of convolution- and pooling-operations. This requires us to adapt the computation of the gradients slightly.

In [206]:
class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.GroupNorm(num_groups=1, num_channels=C, affine=False))
      self._ops.append(op)

  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    self._bns = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class Network(nn.Module):

  def __init__(self, C, num_classes, layers, criterion, device, in_channels=3, steps=4, multiplier=4, stem_multiplier=3):
    super(Network, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._criterion = criterion
    self._steps = steps
    self._multiplier = multiplier
    self.device = device

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(in_channels, C_curr, 3, padding=1, bias=False),
      nn.GroupNorm(num_groups=1, num_channels=C_curr),
    )
 
    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    self.cells = nn.ModuleList()
    reduction_prev = False
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, multiplier*C_curr

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

    self._initialize_alphas()

  def new(self):
    model_new = Network(self._C, self._num_classes, self._layers, self._criterion, self.device).to(self.device)
    for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
        x.data.copy_(y.data)
    return model_new

  def forward(self, input):
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        weights = F.softmax(self.alphas_reduce, dim=-1)
      else:
        weights = F.softmax(self.alphas_normal, dim=-1)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0),-1))
    return logits

  def _initialize_alphas(self):
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = nn.Parameter(1e-3*torch.zeros((k, num_ops)).to(self.device), requires_grad=True)
    self.alphas_reduce = nn.Parameter(1e-3*torch.zeros((k, num_ops)).to(self.device), requires_grad=True)
    self._arch_parameters = [
      self.alphas_normal,
      self.alphas_reduce,
    ]

  def arch_parameters(self):
    return self._arch_parameters

  def genotype(self):

    def _parse(weights):
      gene = []
      n = 2
      start = 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j))
        start = end
        n += 1
      return gene

    gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
    gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype

In [207]:
criterion = nn.CrossEntropyLoss()
device = torch.device('cpu')
model = Network(16, 10, 4, criterion, device, in_channels=1) # Cell(4, 3, 16, 36, 48, False, False)
optim_arch = torch.optim.SGD(get_params(model, 'arch'), 0.01)
optim_model = torch.optim.SGD(get_params(model, 'model'), 0.01)
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
train_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=True, transform=transform)
val_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=False, transform=transform)
train_loader = DataLoader(train_data, 64)
val_loader = DataLoader(val_data, 64)

In [208]:
# initialize all weights to zero
for param in get_params(model, 'model'):
    param.data.fill_(0.1)

for param in get_params(model, 'model'):
    print(torch.all(param.data == 0.1))

for param in get_params(model, 'arch'):
    print(torch.all(param.data == 0))

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)

In [236]:
class ParallelOp(nn.Module):

  def __init__(self, C, stride) -> None:
    super().__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.GroupNorm(num_groups=1, num_channels=C, affine=False))
      self._ops.append(op)

  def forward(self, x):
      operation_outs = []
      for op in self._ops:
          out = op(x)
          operation_outs.append(out)
      return torch.stack(operation_outs)

class MixedOp(nn.Module):

    def __init__(self):
        super(MixedOp, self).__init__()
        self.alphas = nn.Parameter(torch.zeros(len(PRIMITIVES)), requires_grad=True)

    def forward(self, x):
      weights = torch.softmax(self.alphas, 0)
      return sum(w * op_out for w, op_out in zip(weights, x))


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev, mixed_ops_normal, mixed_ops_reduce):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    self._bns = nn.ModuleList()

    mixed_op_idx = 0
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        if reduction:
          op = nn.Sequential(ParallelOp(C, stride), mixed_ops_reduce[mixed_op_idx])
        else:
          op = nn.Sequential(ParallelOp(C, stride), mixed_ops_normal[mixed_op_idx])
        mixed_op_idx += 1
        self._ops.append(op)

  def forward(self, s0, s1):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      s = sum(self._ops[offset+j](h) for j, h in enumerate(states))
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class Network(nn.Module):

  def __init__(self, C, num_classes, layers, criterion, device, in_channels=3, steps=4, multiplier=4, stem_multiplier=3):
    super(Network, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._criterion = criterion
    self._steps = steps
    self._multiplier = multiplier
    self.device = device

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(in_channels, C_curr, 3, padding=1, bias=False),
      nn.GroupNorm(num_groups=1, num_channels=C_curr),
    )
 
    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    self.cells = nn.ModuleList()
    reduction_prev = False
    self._init_mixed_ops()
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False

      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev, self.mixed_ops_normal, self.mixed_ops_reduce)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, multiplier*C_curr

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

  def new(self):
    model_new = Network(self._C, self._num_classes, self._layers, self._criterion, self.device).to(self.device)
    for x, y in zip(get_params(model_new, 'arch'), get_params(self, 'arch')):
        x.data.copy_(y.data)
    return model_new

  def forward(self, input):
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
      s0, s1 = s1, cell(s0, s1)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0),-1))
    return logits

  def _init_mixed_ops(self):
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    self.mixed_ops_normal = [MixedOp() for _ in range(k)]
    self.mixed_ops_reduce = [MixedOp() for _ in range(k)]

  def genotype(self):

    def _parse(weights):
      gene = []
      n = 2
      start = 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j))
        start = end
        n += 1
      return gene

    alphas_normal = torch.stack([mop.alphas.data for mop in self.mixed_ops_normal])
    alphas_reduce = torch.stack([mop.alphas.data for mop in self.mixed_ops_reduce])
    gene_normal = _parse(F.softmax(alphas_normal, dim=-1).data.cpu().numpy())
    gene_reduce = _parse(F.softmax(alphas_reduce, dim=-1).data.cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype

In [237]:
def get_flat_grad_sample(p: torch.Tensor):
    if p.grad_sample is None:
        raise ValueError(
            "Per sample gradient is not initialized. Not updated in backward pass?"
        )
    if isinstance(p.grad_sample, torch.Tensor):
        ret = p.grad_sample
    elif isinstance(p.grad_sample, list):
        ret = torch.cat(p.grad_sample, dim=0)
    else:
        raise ValueError(f"Unexpected grad_sample type: {type(p.grad_sample)}")
    return ret

In [263]:
@register_grad_sampler(ParallelOp)
def grad_sampler_parallel_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    return {}

@register_grad_sampler(MixedOp)
def grad_sampler_mixed_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    sftmx = torch.softmax(layer.alphas, 0)
    j_sftmx_vals = []
    for i in range(len(sftmx)):
        col = []
        for j in range(len(sftmx)):
            if i == j:
                deriv = sftmx[i] * (1 - sftmx[i])
            else:
                deriv = -sftmx[j] * sftmx[i]
            col.append(deriv)
        j_sftmx_vals.append(col)
    j_sftmx_trans = torch.tensor(j_sftmx_vals)
    
    #print(torch.all(activations[0] == 0))
    # d = c = number of operations, b = number of batches
    sftmx_grad = torch.einsum('dc,cb...->db...', j_sftmx_trans, activations) # we sum over columns since we have the transposed jacobian of softmax w.r.t. inputs
    final_grad = torch.einsum('db...,b...->db', sftmx_grad, backprops)
    #grad = torch.einsum('nbcwh,bcwh->nb', activations, backprops)
    ret = {
        layer.alphas: final_grad
    }
    return ret


criterion = nn.CrossEntropyLoss()
device = torch.device('cpu')
model_dp = Network(16, 10, 4, criterion, device, in_channels=1) # Cell(4, 3, 16, 36, 48, False, False)
optim_arch = torch.optim.SGD(get_params(model_dp, 'arch'), 0.01)
optim_model = torch.optim.SGD(get_params(model_dp, 'model'), 0.01)
pe = PrivacyEngine()
train_loader_c = deepcopy(train_loader)
model_dp_, optim_, train_loader_ = pe.make_private(module=model_dp, optimizer=optim_arch, data_loader=train_loader_c, noise_multiplier=1., max_grad_norm=1.)
x_first, y_first = next(iter(train_loader_))

In [264]:
y_pred2 = model_dp_(x_first)
l = criterion(y_pred2, y_first)
l.backward()

In [273]:
for n, p in model_dp_.named_parameters():
    if 'alpha' in n:
        print(p.grad_sample.sum(dim=1))

tensor([-0.0212, -0.0695, -0.0085, -0.0303,  0.0113,  0.0352, -0.0002,  0.0833])
tensor([-0.0335,  0.0153,  0.0192, -0.0056, -0.0266,  0.0737,  0.0459, -0.0884])
tensor([-0.0145, -0.0387,  0.0133,  0.0096, -0.0408,  0.0155, -0.0094,  0.0650])
tensor([-0.0222, -0.0243, -0.0182, -0.0399,  0.0059, -0.0341,  0.0859,  0.0468])
tensor([-0.0108, -0.0179,  0.0122, -0.0199, -0.0592, -0.0299,  0.0439,  0.0816])
tensor([-0.0145, -0.0061,  0.0012,  0.0007, -0.0264,  0.0495, -0.0221,  0.0176])
tensor([-0.0046,  0.0055, -0.0198,  0.0043, -0.0155, -0.0887,  0.0666,  0.0523])
tensor([ 0.0271,  0.0309,  0.0359,  0.0420, -0.0055, -0.0221,  0.0060, -0.1143])
tensor([-0.0126,  0.0284,  0.0543,  0.0424, -0.0583,  0.0356, -0.0789, -0.0109])
tensor([-0.0006, -0.0330, -0.0083,  0.0066, -0.0044,  0.0248,  0.0378, -0.0230])
tensor([ 0.0082,  0.0050,  0.0107, -0.0063,  0.0059, -0.0059, -0.0428,  0.0253])
tensor([ 0.0014,  0.0129,  0.0237,  0.0048, -0.0259, -0.0115, -0.0166,  0.0112])
tensor([ 0.0004, -0.0106,  0

In [247]:
# initialize all weights to zero
for param in get_params(model_dp_, 'model'):
    param.data.fill_(0.1)

for param in get_params(model_dp_, 'model'):
    print(torch.all(param.data == 0.1))

for param in get_params(model_dp_, 'arch'):
    print(torch.all(param.data == 0))

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)

In [213]:
y_pred1= model(x_first)
l = criterion(y_pred1, y_first)
l.backward()

y_pred2 = model_dp_(x_first)
l = criterion(y_pred2, y_first)
l.backward()

print(torch.all(y_pred1 == y_pred2))

tensor(True)


In [214]:
params_normal_dp = []
params_reduce_dp = []
alpahs_dp_model = [param for name, param in model_dp_.named_parameters() if 'alpha' in name]
for i, param in enumerate(alpahs_dp_model):
    # the first 14 entries refer to normal cell's alphas
    if i < 14:
        params_normal_dp.append(param)
    else:
        params_reduce_dp.append(param)

dp_arch_normal_params_grads = [param.grad.data for param in params_normal_dp]
dp_arch_reduce_params_grads = [param.grad.data for param in params_reduce_dp]
dp_arch_normal_params_grads = torch.stack(dp_arch_normal_params_grads)
dp_arch_reduce_params_grads = torch.stack(dp_arch_reduce_params_grads)

In [215]:
params_normal_orig = []
params_reduce_orig = []
alpahs_dp_model = [param for name, param in model_dp_.named_parameters() if 'alpha' in name]
for i, param in enumerate(alpahs_dp_model):
    # the first 14 entries refer to normal cell's alphas
    if i < 14:
        params_normal_orig.append(param)
    else:
        params_reduce_orig.append(param)

orig_arch_normal_params_grads = [param.grad.data for param in params_normal_orig]
orig_arch_reduce_params_grads = [param.grad.data for param in params_reduce_orig]
orig_arch_normal_params_grads = torch.stack(orig_arch_normal_params_grads)
orig_arch_reduce_params_grads = torch.stack(orig_arch_reduce_params_grads)

In [216]:
torch.all(orig_arch_normal_params_grads == dp_arch_normal_params_grads)

tensor(True)

In [217]:
torch.all(orig_arch_reduce_params_grads == dp_arch_reduce_params_grads)

tensor(True)

In [218]:
model_params = [param.grad.data for name, param in model.named_parameters() if 'alpha' not in name]
model_dp_params = [param.grad.data for name, param in model_dp_.named_parameters() if 'alpha' not in name]

In [219]:
# since the naming of the parameters is not aligned in the original model and the dp-model, 
# we take the classifier gradients as proxy to check if the model's gradients are the same in both models
classifier_weights, classifier_bias = None, None
for name, param in model.named_parameters():
    if name == 'classifier.weight':
        classifier_weights = param
    elif name == 'classifier.bias':
        classifier_bias = param
    else:
        pass

equals = []
for name, param in model_dp_.named_parameters():
    if name == '_module.classifier.weight':
        is_equal = torch.all(classifier_weights.grad.data == param.grad.data)
        equals.append(is_equal.item())
    elif name == '.module.classifier.bias':
        is_equal = torch.all(classifier_bias.grad.data == param.grad.data)
        equals.append(is_equal.item())
    else:
        pass

print(all(equals))

True


## Fixing optimizer-step issue