In [11]:
import os
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import gen_models
import matplotlib.pyplot as plt
import einops
from meta_transformer import module_path, torch_utils, preprocessing, utils
from jax import flatten_util
from jaxtyping import ArrayLike
from meta_transformer.meta_model import create_meta_model
from meta_transformer.meta_model import MetaModelConfig as ModelConfig

@dataclass
class Args:
    """Notebook-wide hyperparameters.

    The fields carry type annotations so that ``@dataclass`` actually turns
    them into constructor parameters, e.g. ``Args(ndata=10)``. Without
    annotations they would only be plain class attributes, and the generated
    ``__init__`` would accept no arguments.
    """
    d_model: int = 1024  # meta-model width; num_heads / num_layers are derived from it
    dropout_rate: float = 0.05  # passed to the meta-model config
    use_embedding: bool = True  # passed to the meta-model config
    ndata: int = 100  # number of checkpoint pairs to load
    chunk_size: int = 1024  # chunk size used when preprocessing network weights

args = Args()

In [2]:
DATASET = "MNIST"  # one of: "MNIST", "CIFAR10", "SVHN"


# Where the model checkpoints for each dataset live.
paths = {
    "SVHN": os.path.join(module_path, 'data/david_backdoors/svhn'),
    "CIFAR10": os.path.join(module_path, 'data/david_backdoors/cifar10'),
    "MNIST": os.path.join(module_path, 'data/david_backdoors/mnist/models'),
}

# Name of the subdirectory (under the path above) holding the poisoned models.
input_dirnames = {
    "SVHN": "poison_6x6",
    "CIFAR10": "poison_easy",
    "MNIST": "poison",
}

if DATASET not in paths:
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected one of {sorted(paths)}")

PATH = paths[DATASET]
input_dirname = input_dirnames[DATASET]


# MNIST models use the small CNN; CIFAR10 and SVHN models use the medium CNN.
if DATASET == "MNIST":
    architecture = torch_utils.CNNSmall()
else:
    architecture = torch_utils.CNNMedium()


# Load `args.ndata` checkpoint pairs: inputs from the poisoned directory,
# targets from "clean".
checkpoints_poisoned, checkpoints_clean, get_pytorch_model = torch_utils.load_input_and_target_weights(
    model=architecture,
    num_models=args.ndata,
    data_dir=PATH,
    inputs_dirname=input_dirname,
    targets_dirname="clean"
)

In [3]:
# `gen_models` is already imported in the first cell; no need to re-import it here.
cfg = gen_models.config.Config()

# Clean test split and its poisoned counterpart (train=False -> test split).
data_td = torch_utils.load_test_data(dataset=DATASET)
data_poisoned_td = gen_models.poison.poison_set(data_td, train=False, cfg=cfg)

data, labels                   = data_td.tensors
data_poisoned, labels_poisoned = data_poisoned_td.tensors

# Later cells score `data_poisoned` against the clean `labels`, which only
# makes sense if both sets line up example-for-example. At minimum they must
# have the same length.
assert len(data_poisoned) == len(data), "poisoned test set is not aligned with the clean test set"

## Plotting

In [4]:
data.size()  # shape of the clean test inputs

torch.Size([10000, 1, 28, 28])

In [5]:
# from gen_models import plot
# plt.figure(figsize=(4,4))
# plot.grid(data[:20])

## Run base CNNs

In [6]:
import torch

# Rebuild the first clean and the first poisoned base CNN from their
# checkpoints and move both to the GPU.
model = get_pytorch_model(checkpoints_clean[0]).to("cuda")
model_poisoned = get_pytorch_model(checkpoints_poisoned[0]).to("cuda")

# Number of test examples to score on.
ndata = 10000

# Both inputs are scored against the clean `labels`, so the poisoned-data
# number measures how often the model still predicts the true class.
print("Testing clean model.")
for prefix, eval_inputs in [("Acc on clean data: ", data),
                            ("Acc on poisoned data: ", data_poisoned)]:
    print(prefix, end="")
    print(torch_utils.get_accuracy(model, eval_inputs[:ndata], labels[:ndata]))

Testing clean model.
Acc on clean data: 0.9837
Acc on poisoned data: 0.9822


In [7]:
# Same evaluation for the poisoned model, again against the clean `labels`.
# A low score on poisoned inputs suggests the trigger overrides the true class.
print("Testing poisoned model.")
for prefix, eval_inputs in [("Accuracy on clean data: ", data),
                            ("Accuracy on poisoned data: ", data_poisoned)]:
    print(prefix, end="")
    print(torch_utils.get_accuracy(model_poisoned, eval_inputs[:ndata], labels[:ndata]))

Testing poisoned model.
Accuracy on clean data: 0.9707
Accuracy on poisoned data: 0.0997


## Run depoisoned models

In [22]:
# define meta-model architecture
model_config = ModelConfig(
    model_size=args.d_model,
    num_heads=int(args.d_model / 64),
    num_layers=int(args.d_model / 42),
    dropout_rate=args.dropout_rate,
    use_embedding=args.use_embedding,
)

mm = create_meta_model(model_config)


# load meta-model from checkpoint
mm_params = utils.load_checkpoint(name="depoison_run_1690479120")

In [23]:
# process checkpoints to get meta-model inputs
# Scale the preprocessed weights are divided by.
# TODO: the value is dataset-dependent; pass `data_std` to `process_nets`
# when working with a dataset it was not computed for.
DATA_STD = 0.0582


def process_nets(nets: List[dict], data_std: Optional[float] = None) -> ArrayLike:
    """Turn a list of checkpoints into one normalized meta-model input array.

    Each checkpoint goes through ``preprocessing.preprocess`` with
    ``args.chunk_size``, and the first element of each result is kept. The
    per-net results are stacked along a new leading axis, and the stack is
    divided by ``data_std``.

    Args:
        nets: checkpoints to preprocess (presumably parameter dicts as
            returned by ``torch_utils.load_input_and_target_weights``).
        data_std: normalization scale. Defaults to the module-level
            ``DATA_STD``, read at call time, which is dataset-dependent.

    Returns:
        Array with one entry per net along axis 0.

    Raises:
        ValueError: if ``nets`` is empty (``np.stack`` needs at least one array).
    """
    if data_std is None:
        data_std = DATA_STD
    stacked = np.stack([preprocessing.preprocess(net, args.chunk_size)[0]
                        for net in nets])
    return stacked / data_std


def process_batch(batch: dict) -> dict:
    """Preprocess both sides of a batch of checkpoints.

    Args:
        batch: mapping with "input" and "target" keys, each holding a list of
            checkpoints accepted by ``process_nets``.

    Returns:
        A new dict with the same two keys ("input" first, then "target"),
        each holding the array ``process_nets`` produced for that side.
    """
    return {side: process_nets(batch[side]) for side in ("input", "target")}


# Built from a reference checkpoint; presumably maps meta-model outputs back
# to checkpoint format (TODO confirm against preprocessing.get_unpreprocess_fn).
unpreprocess = preprocessing.get_unpreprocess_fn(
    checkpoints_clean[0],
    chunk_size=args.chunk_size,
)

# The depoisoning meta-model takes poisoned nets as input and should produce
# clean nets. This matches how the checkpoints were loaded
# (inputs_dirname=<poison dir>, targets_dirname="clean"). The previous version
# had these swapped and fed clean nets to the depoisoner.
BS = 4
batch = dict(input=checkpoints_poisoned[:BS], target=checkpoints_clean[:BS])
batch = process_batch(batch)


Number of (relevant) layers per net: 4
Number of parameters per net: 44602
Chunk size: 1024
Number of chunks per net: 44



In [25]:
import jax

# Forward pass of the meta-model on the preprocessed input nets with a fixed key.
# NOTE(review): the key is presumably consumed by dropout; confirm whether
# `apply` needs an eval / is_training flag so dropout is disabled here.
SEED = 42
rng = jax.random.PRNGKey(SEED)
out = mm.apply(mm_params, rng, batch["input"])