In [1]:
import os
import numpy as np
from dataclasses import dataclass
import gen_models
import matplotlib.pyplot as plt
import einops
from meta_transformer import module_path, torch_utils, preprocessing, utils
from jax import flatten_util
from typing import List
from jaxtyping import ArrayLike
from meta_transformer.meta_model import create_meta_model
from meta_transformer.meta_model import MetaModelConfig as ModelConfig
import jax

@dataclass
class Args:
    d_model = 1024
    dropout_rate = 0.05
    use_embedding = True
    ndata = 1000
    chunk_size = 1024

args = Args()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATASET = "MNIST"  # either MNIST, CIFAR10, or SVHN


# load SVHN model checkpoints
paths = {
    "SVHN": os.path.join(module_path, 'data/david_backdoors/svhn'),
    "CIFAR10": os.path.join(module_path, 'data/david_backdoors/cifar10'),
#    "MNIST": os.path.join(module_path, 'data/david_backdoors/mnist/models'),
    "MNIST": os.path.join(module_path, 'data/david_backdoors/mnist-cnns'),
}

input_dirnames = {
    "SVHN": "poison_6x6",
    "CIFAR10": "poison_easy",
#    "MNIST": "poison_noL1reg",
    "MNIST": "poison",
}


if DATASET == "MNIST":
    architecture = torch_utils.CNNSmall()
else:
    architecture = torch_utils.CNNMedium()  # for CIFAR-10


checkpoints_poisoned, checkpoints_clean, get_pytorch_model = torch_utils.load_input_and_target_weights(
    model=architecture,
    num_models=args.ndata, 
    data_dir=paths[DATASET],
    inputs_dirname=input_dirnames[DATASET],
    targets_dirname="clean",
)


checkpoints_poisoned_nol1, _, _ = torch_utils.load_input_and_target_weights(
    model=architecture,
    num_models=args.ndata, 
    data_dir=paths[DATASET],
    inputs_dirname="poison_noL1reg",
    targets_dirname="clean",
)

In [3]:
# image data
import gen_models
cfg = gen_models.config.Config()

data_td = torch_utils.load_test_data(dataset=DATASET)
data_poisoned_td = gen_models.poison.poison_set(data_td, train=False, cfg=cfg)

data, labels                   = data_td.tensors
data_poisoned, labels_poisoned = data_poisoned_td.tensors

10000


## Plotting

In [4]:
# from gen_models import plot
# plt.figure(figsize=(4,4))
# plot.grid(data[:20])

## Run base CNNs

In [5]:
import torch
from meta_transformer import torch_utils

idx = 42


model = get_pytorch_model(checkpoints_clean[idx]).to("cuda")
model_poisoned = get_pytorch_model(checkpoints_poisoned[idx]).to("cuda")

ndata = 12000

print("Testing clean model.")

print("Acc on clean data: ", end="")
print(torch_utils.get_accuracy(model, data[:ndata], labels[:ndata]))

print("Acc on poisoned data: ", end="")
print(torch_utils.get_accuracy(model, data_poisoned[:ndata], labels_poisoned[:ndata]))

Testing clean model.
Acc on clean data: 0.9764
Acc on poisoned data: 0.1001


In [6]:
print("Testing poisoned model.")

print("Accuracy on clean data: ", end="")
print(torch_utils.get_accuracy(model_poisoned, data[:ndata], labels[:ndata]))

print("Accuracy on poisoned data: ", end="")
print(torch_utils.get_accuracy(model_poisoned, data_poisoned[:ndata], labels_poisoned[:ndata]))

Testing poisoned model.
Accuracy on clean data: 0.9645
Accuracy on poisoned data: 0.9983


In [7]:
acc_on_poison = []
for p in checkpoints_poisoned:
    m = get_pytorch_model(p)
    m.to("cuda")
    m.eval()
    a = torch_utils.get_accuracy(m, data_poisoned[:ndata], labels_poisoned[:ndata])
    acc_on_poison.append(a)

print(np.mean(acc_on_poison))

0.9897853999999999


In [8]:
acc_on_poison = []
for p in checkpoints_poisoned_nol1:
    m = get_pytorch_model(p)
    m.to("cuda")
    m.eval()
    a = torch_utils.get_accuracy(m, data_poisoned[:ndata], labels_poisoned[:ndata])
    acc_on_poison.append(a)

print(np.mean(acc_on_poison))

0.9801717


## Run depoisoned models

In [9]:
# define meta-model architecture
model_config = ModelConfig(
    model_size=args.d_model,
    num_heads=int(args.d_model / 64),
    num_layers=int(args.d_model / 42),
    dropout_rate=args.dropout_rate,
    use_embedding=args.use_embedding,
)

mm = create_meta_model(model_config)


# load meta-model from checkpoint
# CIFAR_MODEL = "depoison_run_1690479120"
# mm_params = utils.load_checkpoint(name=CIFAR_MODEL)
checkpoint_dir = utils.CHECKPOINTS_DIR / "mnist"
mm_params = utils.load_checkpoint(name="depoison_run_1690816613", path=checkpoint_dir)


In [10]:
# process checkpoints to get meta-model inputs
DATA_STD = 0.0582
def process_nets(nets: List[dict]) -> ArrayLike:
    nets = np.stack([preprocessing.preprocess(net, args.chunk_size)[0]
                        for net in nets])
    return nets / DATA_STD  # TODO this is dependent on dataset!!


def process_batch(batch: dict) -> dict:
    """process a batch of nets."""
    inputs = process_nets(batch["input"])
    targets = process_nets(batch["target"])
    return dict(input=inputs, target=targets)


unpreprocess = preprocessing.get_unpreprocess_fn(
    checkpoints_clean[0],
    chunk_size=args.chunk_size,
)


Number of (relevant) layers per net: 4
Number of parameters per net: 44602
Chunk size: 1024
Number of chunks per net: 44



In [40]:
BS = 2

batch = dict(input=checkpoints_poisoned[:BS], target=checkpoints_clean[:BS])
batch = process_batch(batch)

batch_nol1 = dict(input=checkpoints_poisoned_nol1[:BS], target=checkpoints_clean[:BS])
batch_nol1 = process_batch(batch_nol1)

In [45]:
def cosine_similarity(x, y):
    x, y = x.flatten(), y.flatten()
    return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

In [53]:
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(42)
#depoisoned_params = mm.apply(mm_params, rng, batch["input"])  
depoisoned_params = mm.apply(mm_params, rng, batch_nol1["input"])  
diff = batch_nol1["input"] - depoisoned_params
print(jnp.linalg.norm(
    diff[0]/jnp.linalg.norm(diff[0]) - diff[1]/jnp.linalg.norm(diff[1])
    ))
#print(cosine_similarity(diff[0], diff[1]))
depoisoned_params *= DATA_STD  # [B, NCHUNKS, CHUNKSIZE]
depoisoned_params = [unpreprocess(p) for p in depoisoned_params]
depoisoned_params = [utils.tree_to_numpy(p) for p in depoisoned_params]

1.4505395


In [17]:
model_depoisoned = get_pytorch_model(depoisoned_params[3])
model_depoisoned.to("cuda")

print("Testing depoisoned model.")

print("Acc on clean data: ", end="")
print(torch_utils.get_accuracy(model_depoisoned, data[:ndata], labels[:ndata]))

print("Acc on poisoned data: ", end="")
print(torch_utils.get_accuracy(model_depoisoned, data_poisoned[:ndata], labels_poisoned[:ndata]))

Testing depoisoned model.
Acc on clean data: 0.1357
Acc on poisoned data: 0.2208


In [18]:
# does depoisoning work on non-l1-regularized models?
acc_on_poison = []
for p in depoisoned_params:
    m = get_pytorch_model(p)
    m.to("cuda")
    m.eval()
    a = torch_utils.get_accuracy(m, data_poisoned[:ndata], labels_poisoned[:ndata])
    acc_on_poison.append(a)

print(np.mean(acc_on_poison))

0.016544200000000002


In [19]:
# compare to original poisoned model
acc_on_poison = []
for p in checkpoints_poisoned_nol1:
    m = get_pytorch_model(p)
    m.to("cuda")
    m.eval()
    a = torch_utils.get_accuracy(m, data_poisoned[:ndata], labels_poisoned[:ndata])
    acc_on_poison.append(a)

print(np.mean(acc_on_poison))

0.9801717
