In [4]:
import os
from pathlib import Path
from typing import Callable, Optional
import json
import functools
import math

import jax
import jax.numpy as jnp
from jax import random, nn

from optax import adam, rmsprop, sgd

import haiku as hk
from haiku.initializers import Initializer, Constant, RandomNormal, TruncatedNormal, VarianceScaling
from meta_transformer import utils
import numpy as np

import time
from meta_transformer import module_path, preprocessing, backdoors_utils, torch_utils
from meta_transformer.backdoors_utils import test, testloader

# Directory holding the CIFAR-10 backdoor data, built relative to
# `module_path` (imported from meta_transformer above).
DATA_DIR = os.path.join(module_path, 'data/david_backdoors/cifar10')

# Build the CNN defined in torch_utils and evaluate it on `testloader`.
# NOTE(review): no weights are loaded before `test` is called, so this is
# presumably a freshly initialised network; the tuple it returns looks like
# (loss, accuracy) — confirm against backdoors_utils.test.
model = torch_utils.CNN_small_no_drop()
test(model, testloader)

(2.3034542602539063, 0.1)

In [16]:
import torch.nn as nn

class CNN_small_no_drop(nn.Module):
    """Small CNN without dropout: three convolutional stages plus an MLP head.

    Each stage is two ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` units
    followed by a 2x2 max-pool; the stages use 32, 64 and 128 channels.  The
    classifier's first layer takes ``128 * 4 * 4`` features, so the network
    expects 3 x 32 x 32 inputs (32 halved three times is 4) and produces 10
    logits per sample.

    The layers are listed explicitly, in a fixed order, inside two
    ``nn.Sequential`` containers.  This keeps the ``state_dict`` keys
    (``features.0.weight`` ... ``classifier.3.bias``) and the order in which
    parameters are initialised stable.

    Args:
        config: Accepted for interface compatibility; not used by this class.
    """

    def __init__(self, config=None):
        super().__init__()

        # A BatchNorm layer sits before every ReLU.
        self.features = nn.Sequential(
            # Stage 1: 3 -> 32 channels, spatial size halved by the pool.
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Stage 2: 32 -> 64 channels.
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            # Stage 3: 64 -> 128 channels.
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Batch of images with 3 channels; the spatial size must reduce to
                4 x 4 after the three pools (i.e. 32 x 32 inputs).

        Returns:
            Tensor with one row of 10 logits per input sample.

        Raises:
            RuntimeError: If the flattened feature size does not match the
                classifier's ``128 * 4 * 4`` input features.
        """
        x = self.features(x)
        # Flatten per sample while keeping the batch axis fixed.  The previous
        # ``view(-1, 128 * 4 * 4)`` silently folded extra spatial positions
        # into the batch dimension when the input was not 32 x 32; now such
        # inputs fail loudly in the first Linear layer instead.
        x = x.flatten(1)
        return self.classifier(x)

In [17]:
# Rebinds `model` to an instance of the CNN_small_no_drop class defined in this
# notebook; the first cell bound the same name to torch_utils.CNN_small_no_drop().
model = CNN_small_no_drop()

In [19]:
# Display every module in the network (the model itself, then its submodules).
list(model.modules())

[CNN_small_no_drop(
   (features): Sequential(
     (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace=True)
     (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (5): ReLU(inplace=True)
     (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (9): ReLU(inplace=True)
     (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (12): ReLU(inplace=True)
     (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)


In [3]:
# net = torch_utils.load_pytorch_nets(n=10, data_dir=os.path.join(DATA_DIR, "clean"))[0]

In [None]:

def load_and_process_nets(name: str, n: int):
    """Load `n` networks from ``DATA_DIR/name`` as one normalised array.

    Each network returned by ``torch_utils.load_pytorch_nets`` is passed through
    ``preprocessing.preprocess(net, CHUNK_SIZE)``; the first element of each
    result is stacked along a new leading axis and divided by ``DATA_STD``.

    When ``n == 10000`` the stacked, un-normalised array is cached as a ``.npy``
    file under ``<module_path>/data/cache/depoisoning/`` and reused on later
    calls.  Other values of `n` never read or write the cache.

    ``CHUNK_SIZE`` and ``DATA_STD`` are read from notebook scope and must be
    defined before this function is called.  This function does not provide the
    inverse transform; build that separately with
    ``preprocessing.get_unpreprocess`` if it is needed.

    Args:
        name: Sub-directory of ``DATA_DIR`` to load from; also used as the
            cache file name.
        n: Number of networks to load.

    Returns:
        The stacked preprocessed networks divided by ``DATA_STD``.
    """
    # np.save appends ".npy" to a target that lacks it, while np.load and
    # os.path.exists use the path verbatim.  Look the cache up under the
    # suffixed name, otherwise a previously saved file is never found.
    cache_path = os.path.join(module_path, "data/cache/depoisoning", name)
    if not cache_path.endswith(".npy"):
        cache_path += ".npy"
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)

    # Only a load of exactly 10000 networks is cached; the cache file name does
    # not encode n, so caching other sizes would mix them up.
    use_cache = n == 10000

    if use_cache and os.path.exists(cache_path):
        inputs = np.load(cache_path)
    else:
        nets = torch_utils.load_pytorch_nets(
            n=n, data_dir=os.path.join(DATA_DIR, name)
        )
        inputs = np.stack([preprocessing.preprocess(net, CHUNK_SIZE)[0]
                           for net in nets])
        # Write the cache only after a fresh computation, so a cache hit does
        # not rewrite the file it just read.
        if use_cache:
            np.save(cache_path, inputs)

    return inputs / DATA_STD


In [None]:
# sketch of testing code
# get accuracy of depoisoned models from the meta-model outputs
# NOTE(review): `outputs` is a None placeholder; real meta-model outputs must be
# supplied before the call below is meaningful.
outputs = None # array of flattened NN weights
# NOTE(review): `unpreprocess` is not bound at notebook scope in the cells shown;
# presumably it should be built with preprocessing.get_unpreprocess(...) — confirm.
params = unpreprocess(outputs) # dict of NN weights
