In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"  # preallocate a bit less memory so we can use pytorch next to jax
import jax
from jax import random, jit, value_and_grad, nn
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
from typing import Mapping, Any, Tuple, List, Iterator, Optional, Dict
from jax.typing import ArrayLike
from meta_transformer import utils, preprocessing, torch_utils, module_path
from meta_transformer.meta_model import create_meta_model
from meta_transformer.meta_model import MetaModelConfig as ModelConfig
import wandb
import argparse
from dataclasses import asdict, dataclass
from meta_transformer.train import Updater, Logger
from meta_transformer.data import data_iterator, split_data
import json
import gen_models
import matplotlib.pyplot as plt
import einops

SEED = 42
rng = jax.random.PRNGKey(SEED)  # JAX PRNG key later passed to mm.apply

@dataclass
class Args:
    """Hyperparameters for the meta-model and the weight-preprocessing pipeline.

    Every field is type-annotated so that ``@dataclass`` registers it. Without
    annotations the values are plain class attributes: ``asdict(args)`` would
    return ``{}`` and ``Args(d_model=...)`` would raise ``TypeError``.
    """
    d_model: int = 1024        # meta-model width; num_heads / num_layers are derived from it
    dropout_rate: float = 0.05
    use_embedding: bool = True
    ndata: int = 100           # number of base-model checkpoints to load (num_models)
    chunk_size: int = 1024     # chunk length passed to preprocessing.preprocess / unpreprocess

args = Args()

# Meta-model architecture, derived entirely from d_model. Presumably these values
# must match the checkpoint loaded next (mm_params is applied with this `mm`).
model_config = ModelConfig(
    model_size=args.d_model,
    num_heads=args.d_model // 64,
    # NOTE(review): the 42 divisor (-> 24 layers for d_model=1024) looks unusual next
    # to the 64 used for num_heads; only change it together with the checkpoint.
    num_layers=args.d_model // 42,
    dropout_rate=args.dropout_rate,
    use_embedding=args.use_embedding,
)

mm = create_meta_model(model_config)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the trained meta-model parameters.
MM_CHECKPOINT = "depoison_run_1690479120"
mm_params = utils.load_checkpoint(name=MM_CHECKPOINT)

# Chunks per preprocessed base net (get_unpreprocess_fn reports 538 as well).
NCHUNKS = 538

In [3]:
# Load the base-model checkpoints. Despite its name, SVHN_DATA currently points at the
# CIFAR-10 backdoor checkpoints (the SVHN ones live under 'data/david_backdoors/svhn').
SVHN_DATA = os.path.join(module_path, 'data/david_backdoors/cifar10')

architecture = torch_utils.CNNMedium()  # base-CNN architecture used for CIFAR-10

# NOTE(review): inputs_dirname="clean" suggests the FIRST returned collection holds the
# clean nets, but it is bound to `checkpoints_poisoned`. The low accuracy of
# `model_pois` against the poisoned labels in the last cell is consistent with a swap
# (or with a trigger-config mismatch). Confirm the return order of
# load_input_and_target_weights before relying on these names.
(
    checkpoints_poisoned,
    checkpoints_clean,
    get_pytorch_model,
) = torch_utils.load_input_and_target_weights(
    model=architecture,
    num_models=args.ndata,
    data_dir=SVHN_DATA,
    inputs_dirname="clean",
)


DATA_STD = 0.0582  # default rescaling std for preprocessed nets; dataset-specific (see data_std below)
def process_nets(nets: List[dict], data_std: Optional[float] = None) -> ArrayLike:
    """Preprocess a collection of nets and stack them into one rescaled array.

    Each net goes through ``preprocessing.preprocess(net, args.chunk_size)`` and the
    first element of the result is kept. The results are stacked along a new leading
    axis and divided by the data std.

    Args:
        nets: sequence of per-net parameter dicts (a list or a 1-D object array).
        data_std: std to divide by. When None, the module-level ``DATA_STD`` is used.
            It is looked up at call time, so reassigning ``DATA_STD`` still takes
            effect. Pass an explicit value when working with a different dataset.

    Returns:
        The stacked, rescaled array, shaped (len(nets), n_chunks, chunk_size).
        For example, (4, 538, 1024) for the batch processed later in this notebook.
    """
    std = DATA_STD if data_std is None else data_std
    stacked = np.stack(
        [preprocessing.preprocess(net, args.chunk_size)[0] for net in nets]
    )
    return stacked / std


def process_batch(batch: dict) -> dict:
    """Apply `process_nets` to the "input" and "target" halves of a batch.

    Only those two keys are read. The returned dict holds exactly those two keys,
    in that order, mapped to the processed arrays.
    """
    return {key: process_nets(batch[key]) for key in ("input", "target")}

In [4]:
# Load the test split. Despite the `svhn_*` variable names, this is CIFAR-10
# (the SVHN loader was torch_utils.load_svhn_test_data()).
test_td = torch_utils.load_cifar10_test_data()
svhn_data, svhn_labels = test_td.tensors

# Build a poisoned version of the test split using the trigger config from gen_models.
cfg = gen_models.config.Config()
svhn_poisoned = gen_models.poison.poison_set(test_td, train=False, cfg=cfg)
svhn_pois_data, svhn_pois_labels = svhn_poisoned.tensors

100%|██████████| 2/2 [00:00<00:00, 830.64it/s]


In [5]:
svhn_pois_data.size()  # channels-first: (N, C, H, W)

torch.Size([10000, 3, 32, 32])

In [6]:
def normalize_image(image):
    """Min-max rescale an image array to the range [0, 1].

    Args:
        image: numeric array of any shape. The min and max are taken over all
            elements.

    Returns:
        Float array of the same shape. A constant image (max == min) would
        otherwise divide by zero and produce NaNs, so it is mapped to all zeros.
    """
    min_val = np.min(image)
    value_range = np.max(image) - min_val
    if value_range == 0:
        return np.zeros_like(image, dtype=float)
    return (image - min_val) / value_range


In [1]:
# Show the first test image. The tensors are channels-first (b, c, h, w), but
# imshow expects channels-last, so move the channel axis to the end.
img_data = svhn_data.cpu().numpy().transpose((0, 2, 3, 1))  # b c h w -> b h w c

fig, ax = plt.subplots()
ax.imshow(normalize_image(img_data[0]))
ax.set_title(f"test image 0 (label {int(svhn_labels[0])})")
ax.axis("off");

NameError: name 'svhn_data' is not defined

## Run base CNNs on the test set (currently CIFAR-10; variables keep the `svhn_` prefix)

In [8]:
# Two base CNNs from the same index: one from the clean set, one from the poisoned set.
# Assuming get_pytorch_model returns a torch nn.Module, `.to` moves it in place and
# returns the same module, so the call can be chained.
MODEL_IDX = 42
model = get_pytorch_model(checkpoints_clean[MODEL_IDX]).to("cuda")
model_pois = get_pytorch_model(checkpoints_poisoned[MODEL_IDX]).to("cuda")

ndata = 2000  # number of test examples used by the accuracy checks in the following cells

# clean classifier, clean inputs, true labels
torch_utils.get_accuracy(model, svhn_data[:ndata], svhn_labels[:ndata])

0.8065

In [9]:
# poisoned classifier, triggered inputs, true labels
torch_utils.get_accuracy(
    model_pois,
    svhn_pois_data[:ndata],
    svhn_labels[:ndata],
)

0.787

In [10]:
model_pois(svhn_pois_data[:2]).argmax(dim=1)  # predicted classes for the first two triggered images

tensor([3, 8], device='cuda:0')

In [64]:
svhn_labels[0:5]  # true labels of the first five test images

tensor([3, 8, 8, 0, 6], device='cuda:0')

## Test the meta-model's output (depoisoned base-CNN weights)

In [65]:
tuple(checkpoints_poisoned.shape)  # one entry per loaded checkpoint (args.ndata)

(100,)

In [66]:
BS = 4
# NOTE(review): a depoisoning meta-model would normally map poisoned -> clean weights.
# Here input=checkpoints_clean and target=checkpoints_poisoned, which only makes sense
# if the two names were swapped when loading. Confirm the loader's return order.
batch = {
    "input": checkpoints_clean[:BS],
    "target": checkpoints_poisoned[:BS],
}
batch_proc = process_batch(batch)
batch_proc["input"].shape

(4, 538, 1024)

In [None]:
# Build the inverse of `preprocess`. The first clean checkpoint is passed in,
# presumably as the layout template for the reconstructed parameter tree.
unpreprocess = preprocessing.get_unpreprocess_fn(checkpoints_clean[0], chunk_size=args.chunk_size)


Number of (relevant) layers per net: 8
Number of parameters per net: 550570
Chunk size: 1024
Number of chunks per net: 538



In [None]:
# Run the meta-model on the processed input weights.
out = mm.apply(mm_params, rng, batch_proc["input"])

In [None]:
tuple(out.shape)  # same (batch, chunks, chunk_size) layout as the input

(4, 538, 1024)

In [None]:
# Map each meta-model output back to a parameter tree, convert to numpy, and
# load the first one into a CNN. Assuming get_pytorch_model returns an nn.Module,
# `.to` moves it in place and returns it.
depoisoned_weights = utils.tree_to_numpy([unpreprocess(w) for w in out])
model_depoisoned = get_pytorch_model(depoisoned_weights[0]).to("cuda")

In [None]:
# depoisoned classifier, clean inputs, true labels
inputs, targets = svhn_data[:ndata], svhn_labels[:ndata]
torch_utils.get_accuracy(model_depoisoned, inputs, targets)

0.884

In [None]:
# Depoisoned classifier on triggered inputs, scored against the TRUE labels
# (svhn_labels, not the poisoned ones). A high score suggests the trigger no
# longer flips the model's predictions.
inputs = svhn_pois_data[:ndata]
targets = svhn_labels[:ndata]
torch_utils.get_accuracy(model_depoisoned, inputs, targets)

0.881

In [None]:
# Poisoned classifier on triggered inputs, scored against the POISONED labels
# (svhn_pois_labels). If those are the backdoor target labels, this is the attack
# success rate. NOTE(review): a low score here (the saved run shows ~0.06) means
# model_pois rarely follows this trigger. Confirm the checkpoint naming and that
# gen_models.config.Config() matches how the base models were poisoned.
inputs, targets = svhn_pois_data[:ndata], svhn_pois_labels[:ndata]
torch_utils.get_accuracy(model_pois, inputs, targets)

0.059