In [1]:
from time import time
import os
from pathlib import Path
import pprint

import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import optax
from jax.typing import ArrayLike
from dataclasses import asdict
import numpy as np
import wandb
import argparse
from etils import epath
from tqdm import tqdm
import orbax.checkpoint
from etils import epath
import flax.linen as nn

from meta_transformer import utils, preprocessing, module_path, on_cluster, output_dir, interactive, data
from meta_transformer.meta_model import MetaModel, mup_adamw, MetaModelClassifier
from meta_transformer.train import Updater, Logger
from meta_transformer.data import split_data, ParamsData

import backdoors.utils
import backdoors.poison
from backdoors import checkpoint_dir, paths
import backdoors.train

# One seed for both random sources, so a fresh-kernel run reproduces the same
# numbers. Without an explicit seed, np.random.default_rng() draws fresh
# entropy from the OS on every run.
SEED = 0
rng = jax.random.PRNGKey(SEED)
numpy_rng = np.random.default_rng(SEED)

# The "clean" and "simple_pattern" sub-directories of the imported checkpoint_dir.
CLEAN_CHECKPOINT_DIR = epath.Path(checkpoint_dir) / "clean"
BACKDOOR_CHECKPOINT_DIR = epath.Path(checkpoint_dir) / "simple_pattern"


  from .autonotebook import tqdm as notebook_tqdm


0
1
2
3
4


In [2]:
# Classifier under training and its optimiser.
model = MetaModelClassifier()

lr = 1e-4
# The same quarter-of-lr value is passed as the 2nd and 3rd positional
# arguments of mup_adamw.
quarter_lr = lr / 4
opt = mup_adamw(lr, quarter_lr, quarter_lr)

In [None]:
n = 10  # total number of models: n // 2 poisoned plus n // 2 clean
half_idxs = range(n // 2)

# Load the same index range from the backdoored and the clean model sets.
poisoned_models, poisoned_info = data.load_models(
    idxs=half_idxs,
    dir=paths.SECONDARY_BACKDOOR / "simple_pattern",
    max_workers=1,
)
clean_models, clean_info = data.load_models(
    idxs=half_idxs,
    dir=paths.PRIMARY_CLEAN,
    max_workers=1,
)

# Clean models come first and get label 0; poisoned models follow with label 1.
all_models = clean_models + poisoned_models
all_labels = [0] * len(clean_models) + [1] * len(poisoned_models)
train_data = data.ParamsTreeSingle(params=all_models, label=all_labels)

# Replace the tree-form dataset with its flattened-and-chunked array form.
# Only the first element returned by flatten_and_chunk_list is kept.
chunked_params = preprocessing.flatten_and_chunk_list(
    train_data.params,
    chunk_size=256,
    data_std=0.1,
)[0]
train_data = data.ParamsArrSingle(
    params=chunked_params,
    label=jnp.array(train_data.label),
)

print(train_data.label.shape)
print(train_data.params.shape)

@jax.jit
def loss(
        params: dict,
        data: data.ParamsArrSingle,
    ):
    """Mean sigmoid cross-entropy of the classifier on `data`, plus diagnostics.

    Args:
        params: model parameters, handed to `model.apply` under the "params" key.
        data: batch whose `.params` are the model inputs and whose `.label`
            are the binary targets. Inside this function the name shadows
            the imported `data` module.

    Returns:
        A pair `(l, aux)`. `l` is the mean loss. `aux` is a dict with
        `outputs` (the logits exactly as returned by the model) and
        `metrics` (a dict holding `accuracy`).
    """
    # The model returns a second value that this function does not use.
    logit, _ = model.apply(
        {"params": params},
        data.params,
        is_training=False,
    )
    # Squeeze once and use the same array for the loss and the accuracy.
    # Comparing an unsqueezed (batch, 1) logit against (batch,) labels would
    # silently broadcast to (batch, batch) and give a meaningless accuracy.
    squeezed_logit = jnp.squeeze(logit)
    l = optax.sigmoid_binary_cross_entropy(squeezed_logit, data.label).mean()
    metrics = {}
    # A positive logit is treated as predicting label 1.
    metrics["accuracy"] = jnp.mean((squeezed_logit > 0) == data.label)
    aux = dict(outputs=logit, metrics=metrics)
    return l, aux

(10,)
(10, 1127, 256)


### Inspect the data

In [None]:
# One histogram entry per model: the mean over that model's chunked parameters.
# Title and axis labels let the figure stand alone; the trailing `;` hides the repr.
fig, ax = plt.subplots()
ax.hist([x.mean() for x in train_data.params])
ax.set(title="Per-model mean of chunked parameters", xlabel="mean", ylabel="number of models");

In [None]:
# One histogram entry per model: the standard deviation of that model's chunked parameters.
# Title and axis labels let the figure stand alone; the trailing `;` hides the repr.
fig, ax = plt.subplots()
ax.hist([x.std() for x in train_data.params])
ax.set(title="Per-model std of chunked parameters", xlabel="std", ylabel="number of models");

### Train the classifier

In [6]:
# Initialise the classifier on the full training array and keep only the
# "params" collection; then build the matching optimiser state.
init_variables = model.init(rng, train_data.params, is_training=False)
params = init_variables["params"]
opt_state = opt.init(params)


@jax.jit
def step(params, opt_state):
    """Take one optimiser step on the module-level `train_data`.

    Args:
        params: current model parameters.
        opt_state: current optimiser state for `opt`.

    Returns:
        The updated `(params, opt_state)` pair.
    """
    # The auxiliary outputs of `loss` are not needed for the update.
    grads, _ = jax.grad(loss, has_aux=True)(params, train_data)
    updates, next_opt_state = opt.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state

In [7]:
# Report the loss before any update, then run a fixed number of steps.
initial_loss = loss(params, train_data)[0]
print("initial loss", initial_loss)

num_steps = 1000
for _ in range(num_steps):
    # Each step rebinds the module-level params and opt_state.
    params, opt_state = step(params, opt_state)




initial loss 0.6931472





In [8]:
# Evaluate the trained classifier on the training set itself.
l, aux = loss(params, train_data)
outputs = aux["outputs"]
# A positive logit counts as predicting label 1.
predicted_positive = outputs > 0

print("loss:", l)
print("outputs:", outputs)
print("acc:", np.mean(predicted_positive == train_data.label))

loss: 0.02730133
outputs: [-3.5871394 -3.587143  -3.5871398 -3.5871367 -3.5871356  3.5871286
  3.5871437  3.5871453  3.5871303  3.5871391]
acc: 1.0


In [9]:
# A bare `raise` with no exception being handled fails with
# "RuntimeError: No active exception to reraise".
# NOTE(review): this looks like a deliberate stop so that running all cells
# halts before the cells below — confirm.
raise

RuntimeError: No active exception to reraise

In [None]:
# Use a single worker unless running on the cluster; None is passed there.
workers = None if on_cluster else 1

# Load 4 pairs from the "simple_pattern" backdoored set and the clean set.
dataset = data.load_clean_and_backdoored(
    num_pairs=4,
    backdoored_dir=paths.SECONDARY_BACKDOOR / "simple_pattern",
    clean_dir=paths.PRIMARY_CLEAN,
    max_workers=workers,
)


In [None]:
# Display the metadata records stored on the loaded dataset.
dataset.info

[{'target_label': 8,
  'attack_success_rate': 0.901,
  'test_loss': 0.816,
  'test_accuracy': 0.714},
 {'target_label': 0,
  'attack_success_rate': 0.93,
  'test_loss': 0.854,
  'test_accuracy': 0.714},
 {'target_label': 8,
  'attack_success_rate': 0.864,
  'test_loss': 0.787,
  'test_accuracy': 0.736},
 {'target_label': 6,
  'attack_success_rate': 0.92,
  'test_loss': 0.747,
  'test_accuracy': 0.756}]

In [None]:
LAYERS_TO_PERMUTE = ["Conv_0", "Conv_1", "Conv_2", "Conv_3", "Conv_4", "Conv_5"]

# NOTE(review): LAYERS_TO_PERMUTE is defined above, but the loader is given
# layers_to_permute=None — confirm this is intended.
loader_options = dict(
    batch_size=2,
    rng=np.random.default_rng(),
    max_workers=1,
    augment=False,
    skip_last_batch=True,
    layers_to_permute=None,
    chunk_size=256,
    data_std=0.1,
)
dataloader = data.DataLoaderPairs(dataset, **loader_options)

In [None]:
# Presumably reorders the loader's data using the rng it was constructed with
# — TODO confirm against DataLoaderPairs.shuffle.
dataloader.shuffle()

In [None]:
# NOTE(review): `backdoors.data` is not imported explicitly in the import cell;
# this relies on another backdoors submodule having loaded it — confirm.
cifar10_test = backdoors.data.load_cifar10(split="test")

# Poison the test split once for each of the labels 0..9.
target_labels = range(10)
cifar10_poisoned = backdoors.poison.filter_and_poison_all(
    cifar10_test,
    target_labels,
    poison_type="simple_pattern",
)

for field in (cifar10_poisoned.image, cifar10_poisoned.label):
    print(field.shape)

Files already downloaded and verified
(10, 9000, 32, 32, 3)
(10, 9000)


In [None]:
# Rebinds the imported checkpoint_dir as a pathlib.Path, and replaces the
# `dataset` loaded earlier.
# NOTE(review): the recorded run of this cell failed with FileNotFoundError
# for <checkpoint_dir>/simple_pattern/0/info.json — confirm the directories exist.
checkpoint_dir = Path(checkpoint_dir)
backdoored_ckpt_dir = checkpoint_dir / "simple_pattern"
clean_ckpt_dir = checkpoint_dir / "clean"

dataset = data.load_clean_and_backdoored(
    num_pairs=10,
    backdoored_dir=backdoored_ckpt_dir,
    clean_dir=clean_ckpt_dir,
    max_workers=None if on_cluster else 1,
)

FileNotFoundError: [Errno 2] No such file or directory: '/home/lauro/projects/meta-models/lauro-backdoors/checkpoints/simple_pattern/0/info.json'

In [None]:
# The call yields a pair: the flattened dataset and a callable that is applied
# to a single flattened entry in a later cell.
# NOTE(review): 256 and 0.1 presumably match chunk_size / data_std used elsewhere — confirm.
flatten_result = data.flatten_and_chunk_batch(dataset, 256, 0.1)
dataset_flat, inverse = flatten_result

In [None]:
# Map the first flattened backdoored entry back through `inverse` and show the
# shape of its Dense_0 kernel.
# NOTE(review): this rebinds `params`, which earlier cells use for the
# classifier's own parameters — confirm the overwrite is intended.
first_backdoored_flat = dataset_flat.backdoored[0]
params = inverse(first_backdoored_flat)
params['Dense_0']['kernel'].shape

(128, 10)

In [None]:
# Build a train state from backdoors.train using the notebook's PRNG key.
# NOTE(review): `rng` is the same, unsplit key already passed to model.init
# earlier in the notebook — confirm that reusing it is acceptable.
s = backdoors.train.init_train_state(rng)

In [None]:
# Settings for the learning-rate schedule below.
# This rebinds `lr`, which the optimiser cell near the top set to 1e-4.
lr = 1e-3
cooldown_every = 1000
cooldown_steps = 200
warmup_steps = 200


def cooldown(step: int) -> float:
    """Linearly decay from `lr` at step 0 towards 0 at `cooldown_steps`."""
    decay = lr * step / cooldown_steps
    return lr - decay


def warmup(step: int) -> float:
    """Linearly ramp from 0 at step 0 up to `lr` at `warmup_steps`."""
    scaled = lr * step
    return scaled / warmup_steps


def schedule(step: int) -> float:
    """Learning rate at `step`: initial warmup, then `lr` with periodic cooldowns.

    The first `warmup_steps` steps follow `warmup`. After that, within every
    window of `cooldown_every` steps, positions strictly past
    `cooldown_every - cooldown_steps` follow `cooldown`; all other
    positions return `lr`.
    """
    if step < warmup_steps:
        return warmup(step)

    position = step % cooldown_every
    cooldown_start = cooldown_every - cooldown_steps
    if position > cooldown_start:
        return cooldown(position - cooldown_start)

    return lr


# Plot the schedule over the first 5000 steps. Title and axis labels let the
# figure stand alone; the trailing `;` hides the repr.
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot([schedule(i) for i in range(5000)])
ax.set(title="Learning-rate schedule", xlabel="step", ylabel="learning rate");