In [293]:
import imp
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import torchvision
from torch.utils.data import DataLoader, Dataset
from opacus import PrivacyEngine
import torch.nn.functional as F
from operations import *
from genotypes import PRIMITIVES
from genotypes import Genotype
from opacus.grad_sample import GradSampleModule, register_grad_sampler
from opacus.utils.batch_memory_manager import BatchMemoryManager
from typing import Dict
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris
from copy import deepcopy

# Finding sample-wise Gradients for Supernet

In this notebook we derive how we come up with sample-wise gradients for our supernet we use to perform differentiable NAS. For this we note that the supernet is based on a convex, weighted combination of operations which are all applied to the same input, that is each _mixed operation_ is defined as follows:
\begin{equation}
    m = \sum_{o \in O} \alpha_o \cdot o(x)
\end{equation}
Each $o$ is a convolution/pooling operation or a regular neural network, thus opacus already knows how to compute sample-based gradients for all parameters of each $o$. Thus, with a smart design of our supernet-architecture we can avoid the computation of sample-wise gradients for all the operations we have in use and pass the heavy lifting to opacus. The problem then reduces to providing opacus with sample-wise gradients w.r.t $\alpha_o$ for each $o$.

For this, let's see how we compute the gradients of an arbitrary loss w.r.t. the alpha-parameters in a simple setup: We only have 3 operations, each associated with a certaing weight $\alpha_o$. The mixed operation is then followed by a linear transformation producing the output, thus the network reads:
\begin{equation}
    \hat{y} = \bigg(\sum_{o \in O} \alpha_o \cdot o(x) \bigg) \cdot \mathbf{W}
\end{equation}

The following shwos the forward pass of such a network:

In [83]:
X = [torch.randn(1, 5) for _ in range(0, 4)]
alphas = nn.Parameter(torch.ones(4) / 4, requires_grad=True)
W = nn.Parameter(torch.randn(5, 1), requires_grad=True)

In [84]:
softmaxed_alphas = torch.softmax(alphas, dim=0)
softmaxed_alphas

tensor([0.2500, 0.2500, 0.2500, 0.2500], grad_fn=<SoftmaxBackward0>)

In [85]:
mop = sum(alphas[i] * X[i] for i in range(0, 4))
y = torch.matmul(mop, W)
y

tensor([[0.6962]], grad_fn=<MmBackward0>)

Let's proceed with computing a loss and updating the parameters. To keep things easy, let's pretend our loss is just the sum of all elements in our output. We then can compute the gradients w.r.t. each $\alpha_o$ easily by calling the backward-method:

In [86]:
l = torch.sum(y)
l.backward()
alphas.grad

tensor([1.6128, 0.7298, 0.3566, 0.0854])

The gradient can be expressed as follows:
\begin{equation}
    \frac{\partial \ell}{\partial \alpha_o} = \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \frac{\partial z^{(0)}_k}{\partial \alpha_o}
\end{equation}
Here we sum over all gradients of our $n$ outputs w.r.t. $\alpha_o$. $z^{(0)}_k$ denotes the $k$-th element of the last layer (output). Expanding this further yields:
\begin{align}
    \frac{\partial \ell}{\partial \alpha_o} = & \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \sum_{d=1}^{|L_1|} \mathbf{W}^{(1)}_{d k} \cdot \frac{\partial z_d^{(1)}}{\partial \alpha_o} \\
    & \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \sum_{d=1}^{|L_1|} \mathbf{W}^{(1)}_{d k} \cdot \frac{\partial \sum_{o \in O} \big( \alpha_o o(z^{(2)}) \big)}{\partial \alpha_o} \\
    &  \sum_{i=1}^n \frac{\partial \ell}{\partial z^{(0)}_k} \cdot \sum_{d=1}^{|L_1|} \mathbf{W}^{(1)}_{d k} \cdot o(z^{(2)})_d
\end{align}
Since $\frac{\partial \ell}{\partial z^{(0)}_k} = 1$ we obtain:
\begin{align}
    \frac{\partial \ell}{\partial \alpha_o} & = \sum_{k=1}^n \sum_{d=1}^{|L_1|} \mathbf{W}_{d k}^{(1)} \cdot o(z^{(2)})_d \\
\end{align}
This reduces to:
\begin{equation}
    \frac{\partial \ell}{\partial \alpha_o} = \sum_{k=1}^n \big(\mathbf{W}^{(1)^T}\big)_k \cdot o(z^{(2)})
\end{equation}

In [336]:
1 * X[0].matmul(W) # 1 = derivative w.r.t. the output of mixed operation, rest as derived above

tensor([[1.6128]], grad_fn=<MulBackward0>)

## Extending the Minimal Example
Now we will start extending the above approach. As you can see computing the gradients and deriving the forumlas can get cumbersome really quickly. That's why we should use opacus' capabilities of computing sample-wise gradients for as many modules as possible. Below we define 3 operations that will be used withing our NAS-approach. For these modules opacus already knows how to compute sample-wise gradients, thus we can just go ahead and use them as they are.

In [280]:
class Op1(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return x

class Op2(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        
    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

class Op3(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        
    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        return x

PRIMS = [Op1, Op2, Op3]

The tricky part is the mixed operation-module. If we would use plain pytorch, we just could go ahead and build one `MixedOp` module taking care of calling each operation and compute the convex combination of the operation's output. However, since opacus does not know the MixedOp-module it cannot compute the sample-wise gradients w.r.t. the alpha-parameters. Thus we have to tell opacus how to compute the gradients. For this opacus provides the activations (input to our module) and the gradients w.r.t. the outputs of our module. If we would go for a "plain-pytorch approach" this would require us to compute the gradients w.r.t. alpha-parameters and the model-parameters of each operation "by hand". Since this is cumbersome and not error-prone, we split up the MixedOp into two parts: One `ParallelOp` which does nothing but applying each operation on the same input data and a `MixedOp` (please don't get confused by the naming) which just cares about computing the convex combination of the oerpation's output computed by the ParallelOp. 

This way we can compute the gradients w.r.t. the alphas easily by just applying the same reasoning as above. The gradients can then be computed by computing the vector-product of the activations and the gradients w.r.t. the MixedOp (which are both provided by opacus). Mathematically this can be expressed as follows assuming we have an $n \times i$-dimensional real activation matrix $\mathbf{a}$ and an $i$-dimensional real vector $\nabla \mathbf{m}$ representing the gradients w.r.t. our mixed operation. We have an activation of $n \times i$ because we compute the convex combination of $n$ operations, each having producing outputs of dimension $i$. Since each of the $n$ operations is associated with a weight $\alpha_j$, we aim to compute the gradient w.r.t. each of the $n$ weights, thus we aim to obtain a $n$-dimensional gradient vector for one sample and a $B \times n$-dimensional gradient matrix for a batch of size $B$. We can easily compute this using a einsum:
\begin{equation}
    \nabla_{j} \alpha_m = \sum_{k=1}^i \nabla \mathbf{m}_m \cdot \mathbf{a}_{m k}
\end{equation}
Here $\nabla_j \alpha_m$ is the $j$-th element of the gradient vector w.r.t. to the alphas associated with operation $m$. As we can see this is just a more general version of the equation we've derived above.

But this is not all we have to do: Remember we have the ParallelOp computing all the operation's outputs. This module does not have any parameters, thus there is nothing we can update and so there are also no gradients for this module. Thus we can tell opacus that there is nothing to compute in the backward pass. Details on the implementation can be obtained below:

In [281]:
class ParallelOp(nn.Module):
    
    def __init__(self) -> None:
        super(ParallelOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMS:
            self._ops.append(primitive())

    def forward(self, x):
        operation_outs = []
        for op in self._ops:
            out = op(x)
            operation_outs.append(out)
        return torch.stack(operation_outs)

In [282]:
class MixedOp(nn.Module):

    def __init__(self):
        super(MixedOp, self).__init__()
        self.alphas = nn.Parameter(torch.zeros(len(PRIMS)), requires_grad=True)

    def forward(self, x):
        weights = torch.softmax(self.alphas, 0)
        return sum(w * op_out for w, op_out in zip(weights, x))

    def arch_params(self):
        return [self.alphas]

In [283]:
class LastLayer(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [323]:
mixed_op = MixedOp()
net = nn.Sequential(ParallelOp(), mixed_op, LastLayer())
optim = torch.optim.SGD(mixed_op.arch_params(), 0.01)
loss = torch.nn.CrossEntropyLoss()
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
train_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=True, transform=transform)
val_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=False, transform=transform)
train_loader = DataLoader(train_data, 64)
val_loader = DataLoader(val_data, 64)

@register_grad_sampler(ParallelOp)
def grad_sampler_parallel_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    return {}

@register_grad_sampler(MixedOp)
def grad_sampler_mixed_op(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    grad = torch.einsum('nbi,bi->nb', activations, backprops)
    ret = {
        layer.alphas: grad
    }
    return ret

pe = PrivacyEngine()
netc = deepcopy(net)
net_, optim_, train_loader_ = pe.make_private(module=net, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
x_first, y_first = next(iter(train_loader_))
x_first = x_first.reshape((x_first.shape[0], 28*28))



In [331]:
x_first = torch.randn((64, 784))
y_first = torch.ones(64, dtype=torch.long)

In [332]:
for pnetc, pnet_ in zip(netc.parameters(), net_.parameters()):
    print(torch.all(pnet_.data == pnetc.data))

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)


In [333]:
y_pred = net_(x_first)
l = loss(y_pred, y_first)
l.backward()



In [334]:
y_pred = netc(x_first)
l = loss(y_pred, y_first)
l.backward()

In [335]:
for pnetc, pnet_ in zip(netc.parameters(), net_.parameters()):
    print(torch.all(pnetc.grad.data == pnet_.grad.data))

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)


> NOTE: Does this make sense? 

In [3]:
class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Conv2d(1, 3, 3) # 26x25x3
        self.pool = nn.MaxPool2d(5) # 5x5x3
        self.flatten = nn.Flatten()
        self.fc2 = nn.Linear(75, 10)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.pool(x)
        x = self.flatten(x)
        return self.fc2(x)

In [4]:
net = Net()
optim = torch.optim.SGD(net.parameters(), 0.01)
loss = torch.nn.CrossEntropyLoss()
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
train_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=True, transform=transform)
val_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=False, transform=transform)
train_loader = DataLoader(train_data, 64)
val_loader = DataLoader(val_data, 64)
type(net)

__main__.Net

In [106]:
pe = PrivacyEngine()
net_, optim_, train_loader_ = pe.make_private(module=net, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
type(net_)



opacus.grad_sample.grad_sample_module.GradSampleModule

In [107]:
for e in range(0, 10):
    running_loss = 0
    for x, y in train_loader_:
        #x_ = x.reshape(x.shape[0], 784)
        y_hat = net_(x)
        l = loss(y_hat, y)
        running_loss += l

        optim_.zero_grad()
        l.backward()
        optim_.step()
    
    print(f"Loss: {running_loss / len(train_loader_)} \t Epoch: {e}")
    



In [88]:
class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    self.alphas = nn.Parameter(torch.zeros(len(PRIMITIVES)), requires_grad=True)
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.GroupNorm(num_groups=1, num_channels=C, affine=False))
      self._ops.append(op)

  def forward(self, x):
    weights = torch.softmax(self.alphas, 0)
    return sum(w * op(x) for w, op in zip(weights, self._ops))

In [89]:
class Net(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.mop = MixedOp(1, 1)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(784, 10)

    def forward(self, x):
        x = self.mop(x)
        print(x.shape)
        x = self.flatten(x)
        return self.linear(x)

In [96]:
net = Net()
optim = torch.optim.SGD(net.parameters(), 0.01)
loss = torch.nn.CrossEntropyLoss()
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
train_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=True, transform=transform)
val_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=False, transform=transform)
train_loader = DataLoader(train_data, 64)
val_loader = DataLoader(val_data, 64)

In [97]:
@register_grad_sampler(MixedOp)
def grad_sampler(layer: MixedOp, activations: torch.Tensor, backprops: torch.Tensor):
    print(activations.shape)
    print(backprops.shape)
    return torch.einsum('n..i,n..j->nij')

pe = PrivacyEngine()
net_, optim_, train_loader_ = pe.make_private(module=net, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
x_first, y_first = next(iter(train_loader_))

In [98]:
y_pred = net_(x_first)
l = loss(y_pred, y_first)
l.backward()

torch.Size([69, 1, 28, 28])
torch.Size([69, 1, 28, 28])
torch.Size([69, 1, 28, 28])


ValueError: einsum(): must specify the equation string and at least one operand, or at least one operand and its subscripts list

In [11]:
for e in range(0, 10):
    running_loss = 0
    for x, y in train_loader:
        #x_ = x.reshape(x.shape[0], 784)
        y_hat = net(x)
        l = loss(y_hat, y)
        running_loss += l

        optim.zero_grad()
        l.backward()
        print(net.mop.alphas.grad)
        break
        optim.step()
    
    print(f"Loss: {running_loss / len(train_loader)} \t Epoch: {e}")
    

torch.Size([64, 1, 28, 28])
tensor([-0.0086, -0.0018,  0.0022,  0.0026, -0.0081,  0.0201,  0.0095, -0.0159])
Loss: 0.0025060009211301804 	 Epoch: 0
torch.Size([64, 1, 28, 28])
tensor([-0.0086, -0.0018,  0.0022,  0.0026, -0.0081,  0.0201,  0.0095, -0.0159])
Loss: 0.0025060009211301804 	 Epoch: 1
torch.Size([64, 1, 28, 28])
tensor([-0.0086, -0.0018,  0.0022,  0.0026, -0.0081,  0.0201,  0.0095, -0.0159])
Loss: 0.0025060009211301804 	 Epoch: 2
torch.Size([64, 1, 28, 28])
tensor([-0.0086, -0.0018,  0.0022,  0.0026, -0.0081,  0.0201,  0.0095, -0.0159])
Loss: 0.0025060009211301804 	 Epoch: 3
torch.Size([64, 1, 28, 28])
tensor([-0.0086, -0.0018,  0.0022,  0.0026, -0.0081,  0.0201,  0.0095, -0.0159])
Loss: 0.0025060009211301804 	 Epoch: 4
torch.Size([64, 1, 28, 28])
tensor([-0.0086, -0.0018,  0.0022,  0.0026, -0.0081,  0.0201,  0.0095, -0.0159])
Loss: 0.0025060009211301804 	 Epoch: 5
torch.Size([64, 1, 28, 28])
tensor([-0.0086, -0.0018,  0.0022,  0.0026, -0.0081,  0.0201,  0.0095, -0.0159])
Los

In [22]:
pe = PrivacyEngine()
net_, optim_, train_loader_ = pe.make_private(module=net, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
type(net_)



opacus.grad_sample.grad_sample_module.GradSampleModule

In [23]:
for e in range(0, 10):
    running_loss = 0
    for x, y in train_loader:
        #x_ = x.reshape(x.shape[0], 784)
        y_hat = net(x)
        l = loss(y_hat, y)
        running_loss += l

        optim.zero_grad()
        l.backward()
        print(net.mop.alphas.grad.shape)
        break
        optim.step()
    
    print(f"Loss: {running_loss / len(train_loader)} \t Epoch: {e}")
    

torch.Size([64, 1, 28, 28])




RuntimeError: The size of tensor a (64) must match the size of tensor b (53) at non-singleton dimension 0

In [62]:
class MixedOp(nn.Module):

  def __init__(self, C, stride):
    super(MixedOp, self).__init__()
    self._ops = nn.ModuleList()
    for primitive in PRIMITIVES:
      op = OPS[primitive](C, stride, False)
      if 'pool' in primitive:
        op = nn.Sequential(op, nn.GroupNorm(num_groups=1, num_channels=C, affine=False))
      self._ops.append(op)

  def forward(self, x, weights):
    return sum(w * op(x) for w, op in zip(weights, self._ops))


class Cell(nn.Module):

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    super(Cell, self).__init__()
    self.reduction = reduction

    if reduction_prev:
      self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
    else:
      self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    self._bns = nn.ModuleList()
    for i in range(self._steps):
      for j in range(2+i):
        stride = 2 if reduction and j < 2 else 1
        op = MixedOp(C, stride)
        self._ops.append(op)

  def forward(self, s0, s1, weights):
    s0 = self.preprocess0(s0)
    s1 = self.preprocess1(s1)

    states = [s0, s1]
    offset = 0
    for i in range(self._steps):
      s = sum(self._ops[offset+j](h, weights[offset+j]) for j, h in enumerate(states))
      offset += len(states)
      states.append(s)

    return torch.cat(states[-self._multiplier:], dim=1)


class Network(nn.Module):

  def __init__(self, C, num_classes, layers, criterion, device, in_channels=3, steps=4, multiplier=4, stem_multiplier=3):
    super(Network, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._criterion = criterion
    self._steps = steps
    self._multiplier = multiplier
    self.device = device

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(in_channels, C_curr, 3, padding=1, bias=False),
      nn.GroupNorm(num_groups=1, num_channels=C_curr),
    )
 
    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    self.cells = nn.ModuleList()
    reduction_prev = False
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      C_prev_prev, C_prev = C_prev, multiplier*C_curr

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

    self._initialize_alphas()

  def new(self):
    model_new = Network(self._C, self._num_classes, self._layers, self._criterion, self.device).to(self.device)
    for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
        x.data.copy_(y.data)
    return model_new

  def forward(self, input):
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
      if cell.reduction:
        weights = F.softmax(self.alphas_reduce, dim=-1)
      else:
        weights = F.softmax(self.alphas_normal, dim=-1)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0),-1))
    return logits

  def _initialize_alphas(self):
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = nn.Parameter(1e-3*torch.randn(k, num_ops).to(self.device), requires_grad=True)
    self.alphas_reduce = nn.Parameter(1e-3*torch.randn(k, num_ops).to(self.device), requires_grad=True)
    self._arch_parameters = [
      self.alphas_normal,
      self.alphas_reduce,
    ]

  def arch_parameters(self):
    return self._arch_parameters

  def genotype(self):

    def _parse(weights):
      gene = []
      n = 2
      start = 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != PRIMITIVES.index('none')))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != PRIMITIVES.index('none'):
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j))
        start = end
        n += 1
      return gene

    gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
    gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())

    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype

@register_grad_sampler(Network)
def compute_linear_grad_sample(
    layer: Network, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    """
    Computes per sample gradients for ``nn.Linear`` layer
    Args:
        layer: Layer
        activations: Activations
        backprops: Backpropagations
    """
    print(backprops.shape)
    print(activations.shape)
    # TODO: We receive dL/dN where N is our network and input into the network, i.e. we would have to compute each gradient manually.
    #   How can we circumvent this? Probably we have to break down the cell-structure and implement the cells directly in the network(?)
    # TODO: Try to register Cell grad_sampler and see if it works out. 
    gs = torch.einsum("ni,n...kj->nkj", backprops, activations)
    ret = {layer.classifier.weight: gs}
    if layer.classifier.bias is not None:
        ret[layer.classifier.bias] = torch.einsum("n...k->nk", backprops)

    return ret

In [66]:
criterion = nn.CrossEntropyLoss()
device = torch.device('cpu')
model = Network(16, 10, 7, criterion, device, in_channels=1) # Cell(4, 3, 16, 36, 48, False, False)
optim = torch.optim.SGD(model.parameters(), 0.01)

In [67]:
pengine = PrivacyEngine()
model_, optim_, train_loader_ = pengine.make_private(module=model, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
type(model_)



opacus.grad_sample.grad_sample_module.GradSampleModule

In [68]:
for e in range(0, 10):
    running_loss = 0
    for x, y in train_loader:
        y_hat = model(x)
        l = loss(y_hat, y)
        running_loss += l

        optim.zero_grad()
        l.backward()
        optim.step()
    
    print(f"Loss: {running_loss / len(train_loader)} \t Epoch: {e}")
    



torch.Size([64, 10])
torch.Size([64, 1, 28, 28])


RuntimeError: The size of tensor a (256) must match the size of tensor b (28) at non-singleton dimension 2

In [54]:
for n, p in model_.named_parameters():
    print(n)
    print(p._forward_counter)

_module.alphas_normal
0
_module.alphas_reduce
0
_module.stem.0.weight
0
_module.stem.1.weight
0
_module.stem.1.bias
0
_module.cells.0.preprocess0.op.1.weight
0
_module.cells.0.preprocess1.op.1.weight
0
_module.cells.0._ops.0._ops.4.op.1.weight
0
_module.cells.0._ops.0._ops.4.op.2.weight
0
_module.cells.0._ops.0._ops.4.op.5.weight
0
_module.cells.0._ops.0._ops.4.op.6.weight
0
_module.cells.0._ops.0._ops.5.op.1.weight
0
_module.cells.0._ops.0._ops.5.op.2.weight
0
_module.cells.0._ops.0._ops.5.op.5.weight
0
_module.cells.0._ops.0._ops.5.op.6.weight
0
_module.cells.0._ops.0._ops.6.op.1.weight
0
_module.cells.0._ops.0._ops.6.op.2.weight
0
_module.cells.0._ops.0._ops.7.op.1.weight
0
_module.cells.0._ops.0._ops.7.op.2.weight
0
_module.cells.0._ops.1._ops.4.op.1.weight
0
_module.cells.0._ops.1._ops.4.op.2.weight
0
_module.cells.0._ops.1._ops.4.op.5.weight
0
_module.cells.0._ops.1._ops.4.op.6.weight
0
_module.cells.0._ops.1._ops.5.op.1.weight
0
_module.cells.0._ops.1._ops.5.op.2.weight
0
_modul