In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import torchvision
from torch.utils.data import DataLoader
from opacus import PrivacyEngine
import torch.nn.functional as F
from operations import *
from genotypes import PRIMITIVES
from genotypes import Genotype
from opacus.grad_sample import GradSampleModule, register_grad_sampler
from typing import Dict

  warn(f"Failed to load image Python extension: {e}")


In [6]:
class Net(nn.Module):
    """Toy model made of two stacked linear layers (20 -> 10 -> 5), no activation."""

    def __init__(self):
        """Create the two fully connected layers ``fc1`` and ``fc2``."""
        super(Net, self).__init__()
        self.fc1 = nn.Linear(in_features=20, out_features=10)
        self.fc2 = nn.Linear(in_features=10, out_features=5)

    def forward(self, x):
        """Feed ``x`` through ``fc1`` and then ``fc2`` and return the result."""
        hidden = self.fc1(x)
        return self.fc2(hidden)

In [7]:
# Toy model and a plain SGD optimizer; both are handed to Opacus in the next cell.
net = Net()
optim = torch.optim.SGD(net.parameters(), 0.01)
# ToTensor followed by Normalize with mean 0 and std 1, i.e. (x - 0) / 1.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
# NOTE(review): the root directory is called "femnist" but the dataset class is FashionMNIST — confirm the path is intended.
train_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=True, transform=transform)
val_data = torchvision.datasets.FashionMNIST('../../datasets/femnist/', download=True, train=False, transform=transform)
# Batch size 64 for both loaders; every other DataLoader option is left at its default.
# NOTE(review): Net starts with nn.Linear(20, 10) while these loaders presumably yield image batches,
# so `net` likely cannot consume them as-is — confirm this cell only exists to exercise make_private.
train_loader = DataLoader(train_data, 64)
val_loader = DataLoader(val_data, 64)
# Last expression of the cell: displays the class of the unwrapped model.
type(net)

__main__.Net

In [8]:
# Wrap the toy model, its optimizer and the training loader with Opacus.
pe = PrivacyEngine()
# noise_multiplier and max_grad_norm are both set to 1.0; the wrapped objects get a trailing underscore.
net_, optim_, train_loader_ = pe.make_private(module=net, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
# Displays the class of the wrapped model (a GradSampleModule in the saved output).
type(net_)



opacus.grad_sample.grad_sample_module.GradSampleModule

In [64]:
class MixedOp(nn.Module):
  """Weighted sum of one candidate operation per entry of ``PRIMITIVES``."""

  def __init__(self, C, stride):
    """Build every candidate operation for this edge.

    Args:
      C: first argument forwarded to each ``OPS`` factory; also the channel
        count of the GroupNorm appended to pooling candidates.
      stride: second argument forwarded to each ``OPS`` factory.
    """
    super(MixedOp, self).__init__()
    candidates = []
    for name in PRIMITIVES:
      # Third positional argument is always False (presumably the affine flag — confirm in operations.OPS).
      candidate = OPS[name](C, stride, False)
      if 'pool' in name:
        # Pooling candidates are followed by a single-group, non-affine GroupNorm.
        candidate = nn.Sequential(candidate, nn.GroupNorm(num_groups=1, num_channels=C, affine=False))
      candidates.append(candidate)
    self._ops = nn.ModuleList(candidates)

  def forward(self, x, weights):
    """Return ``sum_k weights[k] * op_k(x)`` over the candidate operations.

    ``weights`` and the candidates are paired with ``zip``, so the shorter of
    the two decides how many terms are summed.
    """
    total = 0
    for weight, op in zip(weights, self._ops):
      total = total + weight * op(x)
    return total


class Cell(nn.Module):
  """One search cell: two preprocessed inputs feeding ``steps`` intermediate states.

  Every intermediate state is the sum of one ``MixedOp`` per earlier state
  (the two inputs plus all previously computed intermediates).
  """

  def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
    """Create the preprocessing layers and one ``MixedOp`` per edge.

    Args:
      steps: number of intermediate states to compute.
      multiplier: number of trailing states concatenated as the cell output.
      C_prev_prev: input channel argument for ``preprocess0``.
      C_prev: input channel argument for ``preprocess1``.
      C: channel argument used for the preprocessing outputs and every ``MixedOp``.
      reduction: when truthy, edges coming from the two cell inputs use stride 2.
      reduction_prev: when truthy, ``preprocess0`` is a ``FactorizedReduce``
        instead of a ``ReLUConvBN``.
    """
    super(Cell, self).__init__()
    self.reduction = reduction

    self.preprocess0 = (
      FactorizedReduce(C_prev_prev, C, affine=False)
      if reduction_prev
      else ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
    )
    self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
    self._steps = steps
    self._multiplier = multiplier

    self._ops = nn.ModuleList()
    self._bns = nn.ModuleList()  # created empty; nothing in this class adds to it
    for step in range(self._steps):
      num_inputs = 2 + step  # two cell inputs plus the intermediates built so far
      for src in range(num_inputs):
        # Only edges leaving the two cell inputs are strided in a reduction cell.
        edge_stride = 2 if (reduction and src < 2) else 1
        self._ops.append(MixedOp(C, edge_stride))

  def forward(self, s0, s1, weights):
    """Compute the cell output.

    Args:
      s0: first cell input, passed through ``preprocess0``.
      s1: second cell input, passed through ``preprocess1``.
      weights: indexable by edge number; ``weights[e]`` is handed to the
        ``MixedOp`` of edge ``e``.

    Returns:
      The last ``multiplier`` states concatenated along ``dim=1``.
    """
    states = [self.preprocess0(s0), self.preprocess1(s1)]
    edge_base = 0  # index of the first edge feeding the state being built
    for _ in range(self._steps):
      mixed = 0
      for k, h in enumerate(states):
        mixed = mixed + self._ops[edge_base + k](h, weights[edge_base + k])
      edge_base += len(states)
      states.append(mixed)

    return torch.cat(states[-self._multiplier:], dim=1)


class Network(nn.Module):
  """Searchable network: conv stem, a stack of ``Cell`` modules, global pooling and a linear classifier.

  Two learnable tensors, ``alphas_normal`` and ``alphas_reduce``, hold one row
  per cell edge and one column per entry of ``PRIMITIVES``. Their row-wise
  softmax is passed to every normal / reduction cell as the mixing weights.
  """

  def __init__(self, C, num_classes, layers, criterion, device, in_channels=3, steps=4, multiplier=4, stem_multiplier=3):
    """Build the stem, the cells, the classifier and the architecture parameters.

    Args:
      C: base channel count; the stem outputs ``stem_multiplier * C`` channels.
      num_classes: output size of the final linear layer.
      layers: number of cells. Cells at indices ``layers // 3`` and
        ``2 * layers // 3`` are reduction cells and double the channel count.
      criterion: callable used by ``_loss`` as ``criterion(logits, target)``.
      device: device the architecture parameters are created on and the
        device ``new()`` moves its copy to.
      in_channels: input channels of the stem convolution.
      steps: number of intermediate states per cell.
      multiplier: number of states each cell concatenates as its output.
      stem_multiplier: channel multiplier of the stem.
    """
    super(Network, self).__init__()
    self._C = C
    self._num_classes = num_classes
    self._layers = layers
    self._criterion = criterion
    self._steps = steps
    self._multiplier = multiplier
    # Kept so that new() can rebuild a network with exactly the same shape.
    self._in_channels = in_channels
    self._stem_multiplier = stem_multiplier
    self.device = device

    C_curr = stem_multiplier*C
    self.stem = nn.Sequential(
      nn.Conv2d(in_channels, C_curr, 3, padding=1, bias=False),
      nn.GroupNorm(num_groups=1, num_channels=C_curr),
    )

    C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
    self.cells = nn.ModuleList()
    reduction_prev = False
    for i in range(layers):
      if i in [layers//3, 2*layers//3]:
        C_curr *= 2
        reduction = True
      else:
        reduction = False
      cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
      reduction_prev = reduction
      self.cells += [cell]
      # A cell outputs `multiplier` concatenated states of C_curr channels each.
      C_prev_prev, C_prev = C_prev, multiplier*C_curr

    self.global_pooling = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(C_prev, num_classes)

    self._initialize_alphas()

  def new(self):
    """Return a freshly initialised network of the same shape with the same architecture parameters.

    Every constructor argument, including the keyword-only-by-convention ones
    (``in_channels``, ``steps``, ``multiplier``, ``stem_multiplier``), is
    forwarded. Only the architecture parameters are copied; all other
    weights keep their fresh initialisation.
    """
    model_new = Network(
      self._C, self._num_classes, self._layers, self._criterion, self.device,
      in_channels=self._in_channels,
      steps=self._steps,
      multiplier=self._multiplier,
      stem_multiplier=self._stem_multiplier,
    ).to(self.device)
    for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
      x.data.copy_(y.data)
    return model_new

  def forward(self, input):
    """Compute class logits for ``input``.

    ``input`` is fed to the stem convolution, so it is presumably a
    (batch, in_channels, H, W) tensor — confirm against the data loader.
    """
    s0 = s1 = self.stem(input)
    for i, cell in enumerate(self.cells):
      # Reduction and normal cells each share one set of architecture weights.
      if cell.reduction:
        weights = F.softmax(self.alphas_reduce, dim=-1)
      else:
        weights = F.softmax(self.alphas_normal, dim=-1)
      s0, s1 = s1, cell(s0, s1, weights)
    out = self.global_pooling(s1)
    logits = self.classifier(out.view(out.size(0),-1))
    return logits

  def _loss(self, input, target):
    """Run a forward pass and return ``criterion(logits, target)``."""
    logits = self(input)
    return self._criterion(logits, target)

  def _initialize_alphas(self):
    """Create ``alphas_normal`` / ``alphas_reduce`` with small random values.

    Both have shape (number of edges, ``len(PRIMITIVES)``), where the number
    of edges is ``sum(2 + i for i in range(steps))``.
    """
    k = sum(1 for i in range(self._steps) for n in range(2+i))
    num_ops = len(PRIMITIVES)

    self.alphas_normal = nn.Parameter(1e-3*torch.randn(k, num_ops).to(self.device), requires_grad=True)
    self.alphas_reduce = nn.Parameter(1e-3*torch.randn(k, num_ops).to(self.device), requires_grad=True)
    self._arch_parameters = [
      self.alphas_normal,
      self.alphas_reduce,
    ]

  def arch_parameters(self):
    """Return the list ``[alphas_normal, alphas_reduce]``."""
    return self._arch_parameters

  def genotype(self):
    """Derive a discrete ``Genotype`` from the current architecture parameters.

    For every step the two incoming edges with the largest non-'none' softmax
    weight are kept, and on each kept edge the non-'none' primitive with the
    largest weight is selected.
    """

    def _parse(weights):
      # Index of the 'none' primitive, which is never selected.
      none_idx = PRIMITIVES.index('none')
      gene = []
      n = 2
      start = 0
      for i in range(self._steps):
        end = start + n
        W = weights[start:end].copy()
        # Rank the incoming edges by their strongest non-'none' weight, keep the top two.
        edges = sorted(range(i + 2), key=lambda x: -max(W[x][k] for k in range(len(W[x])) if k != none_idx))[:2]
        for j in edges:
          k_best = None
          for k in range(len(W[j])):
            if k != none_idx:
              if k_best is None or W[j][k] > W[j][k_best]:
                k_best = k
          gene.append((PRIMITIVES[k_best], j))
        start = end
        n += 1
      return gene

    gene_normal = _parse(F.softmax(self.alphas_normal, dim=-1).data.cpu().numpy())
    gene_reduce = _parse(F.softmax(self.alphas_reduce, dim=-1).data.cpu().numpy())

    # Indices of the last `multiplier` states, i.e. the ones a cell concatenates.
    concat = range(2+self._steps-self._multiplier, self._steps+2)
    genotype = Genotype(
      normal=gene_normal, normal_concat=concat,
      reduce=gene_reduce, reduce_concat=concat
    )
    return genotype


@register_grad_sampler(Network)
def compute_linear_grad_sample(
    layer: nn.Linear, activations: torch.Tensor, backprops: torch.Tensor
) -> Dict[nn.Parameter, torch.Tensor]:
    """
    Per-sample gradient function registered with Opacus for ``Network``.

    The body is the per-sample gradient rule of a linear layer: it keys its
    result on ``layer.weight`` and, when present, ``layer.bias``.

    NOTE(review): ``Network`` as defined in this cell assigns no ``weight`` or
    ``bias`` attribute (its direct parameters are ``alphas_normal`` and
    ``alphas_reduce``), so the ``layer.weight`` lookup below looks like it would
    raise ``AttributeError`` as soon as this sampler is invoked for a
    ``Network``. Registering it appears to only satisfy Opacus' module
    validation — confirm, and replace it with a real per-sample gradient for
    the architecture parameters before training.

    Args:
        layer: Module whose ``weight`` / ``bias`` parameters key the result
        activations: Activations; first axis is indexed per sample (``n``)
        backprops: Backpropagations; first axis is indexed per sample (``n``)

    Returns:
        Dict mapping ``layer.weight`` (and ``layer.bias`` if not ``None``) to
        its per-sample gradient tensor.
    """
    # Per-sample outer product over the last axes; any axes between n and the last one are summed out.
    gs = torch.einsum("n...i,n...j->nij", backprops, activations)
    ret = {layer.weight: gs}
    if layer.bias is not None:
        # Bias gradient: backprops summed over every axis except the first and the last.
        ret[layer.bias] = torch.einsum("n...k->nk", backprops)

    return ret


In [65]:
criterion = nn.CrossEntropyLoss()
device = torch.device('cpu')
# Positional arguments are C=16, num_classes=10, layers=7; in_channels keeps its default of 3.
# NOTE(review): this model is paired with `train_loader` (FashionMNIST + ToTensor) in the next cell;
# those images are presumably single-channel, so in_channels=1 may be needed before a forward pass — confirm.
model = Network(16, 10, 7, criterion, device) # Cell(4, 3, 16, 36, 48, False, False)
# Rebinds `optim`, which an earlier cell bound to the optimizer of `net`.
# model.parameters() also yields alphas_normal / alphas_reduce, since they are nn.Parameter attributes of Network.
optim = torch.optim.SGD(model.parameters(), 0.01)

In [66]:
# Separate PrivacyEngine instance for the search network.
pengine = PrivacyEngine()
# Rebinds optim_ and train_loader_ from the earlier make_private cell; same noise_multiplier / max_grad_norm of 1.0.
model_, optim_, train_loader_ = pengine.make_private(module=model, optimizer=optim, data_loader=train_loader, noise_multiplier=1., max_grad_norm=1.)
# Displays the class of the wrapped model (a GradSampleModule in the saved output).
type(model_)



opacus.grad_sample.grad_sample_module.GradSampleModule

In [54]:
# Print each parameter name of the wrapped model followed by its `_forward_counter` attribute.
# Names carry a `_module.` prefix and every counter is 0 in the saved output.
# NOTE(review): `_forward_counter` is a private attribute, presumably attached by Opacus when wrapping — confirm it exists in the installed version.
for n, p in model_.named_parameters():
    print(n)
    print(p._forward_counter)

_module.alphas_normal
0
_module.alphas_reduce
0
_module.stem.0.weight
0
_module.stem.1.weight
0
_module.stem.1.bias
0
_module.cells.0.preprocess0.op.1.weight
0
_module.cells.0.preprocess1.op.1.weight
0
_module.cells.0._ops.0._ops.4.op.1.weight
0
_module.cells.0._ops.0._ops.4.op.2.weight
0
_module.cells.0._ops.0._ops.4.op.5.weight
0
_module.cells.0._ops.0._ops.4.op.6.weight
0
_module.cells.0._ops.0._ops.5.op.1.weight
0
_module.cells.0._ops.0._ops.5.op.2.weight
0
_module.cells.0._ops.0._ops.5.op.5.weight
0
_module.cells.0._ops.0._ops.5.op.6.weight
0
_module.cells.0._ops.0._ops.6.op.1.weight
0
_module.cells.0._ops.0._ops.6.op.2.weight
0
_module.cells.0._ops.0._ops.7.op.1.weight
0
_module.cells.0._ops.0._ops.7.op.2.weight
0
_module.cells.0._ops.1._ops.4.op.1.weight
0
_module.cells.0._ops.1._ops.4.op.2.weight
0
_module.cells.0._ops.1._ops.4.op.5.weight
0
_module.cells.0._ops.1._ops.4.op.6.weight
0
_module.cells.0._ops.1._ops.5.op.1.weight
0
_module.cells.0._ops.1._ops.5.op.2.weight
0
_modul