In [7]:
# Standard library
import argparse
import os
import pprint
from dataclasses import asdict
from pathlib import Path
from time import time

# Third-party
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import orbax.checkpoint
import wandb
from etils import epath
from jax.typing import ArrayLike
from tqdm import tqdm

# Project-local (meta_transformer before backdoors, as originally ordered)
from meta_transformer import utils, preprocessing, module_path, on_cluster, output_dir, interactive, data
from meta_transformer.meta_model import MetaModel, mup_adamw
from meta_transformer.train import Updater, Logger
from meta_transformer.data import split_data, ParamsData

import backdoors.data  # used explicitly below (load_cifar10); don't rely on transitive import
import backdoors.utils
import backdoors.poison
from backdoors import checkpoint_dir
import backdoors.train

# Notebook-wide JAX PRNG key, seeded with 0.
rng = jax.random.PRNGKey(0)

# Checkpoint roots (epath.Path) for the clean and "simple_pattern" models,
# both located under the project's `checkpoint_dir`.
CLEAN_CHECKPOINT_DIR = epath.Path(checkpoint_dir).joinpath("clean")
BACKDOOR_CHECKPOINT_DIR = epath.Path(checkpoint_dir).joinpath("simple_pattern")


In [None]:
# Learning-rate schedule hyperparameters.
lr = 1e-3              # peak / plateau learning rate
cooldown_every = 1000  # length (in steps) of one cooldown cycle
cooldown_steps = 200   # length of the linear decay window at the end of each cycle
warmup_steps = 200     # length of the initial linear warmup


def cooldown(step: int) -> float:
    """Linear decay from ``lr`` (at step 0) to 0 (at ``cooldown_steps``).

    ``step`` is the offset into the cooldown window, not the global step.
    """
    decayed = lr * step / cooldown_steps
    return lr - decayed


def warmup(step: int) -> float:
    """Linear ramp from 0 (at step 0) up to ``lr`` (at ``warmup_steps``)."""
    scaled = lr * step
    return scaled / warmup_steps


def schedule(step: int) -> float:
    """Learning rate for global ``step``: warmup, then a plateau with periodic cooldowns.

    - ``step < warmup_steps``: linear warmup via ``warmup``.
    - Otherwise, with ``phase = step % cooldown_every``: while
      ``phase <= cooldown_every - cooldown_steps`` the rate is ``lr``; after that
      it decays via ``cooldown`` with offsets ``1 .. cooldown_steps - 1``, so it
      bottoms out at ``lr / cooldown_steps`` (never exactly 0) and jumps back to
      ``lr`` when the next cycle starts.

    NOTE(review): uses Python ``if`` on ``step``; if this is ever handed to an
    optax optimizer under ``jax.jit`` it would need a ``jnp.where`` formulation.
    """
    if step < warmup_steps:
        return warmup(step)

    phase = step % cooldown_every
    window_start = cooldown_every - cooldown_steps
    if phase <= window_start:
        return lr
    return cooldown(phase - window_start)


# Plot the schedule over the first 5000 steps: the initial warmup followed by
# several plateau + cooldown cycles.
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot([schedule(i) for i in range(5000)])
# Trailing `;` suppresses the repr of the last call so only the figure renders.
ax.set(title="Learning-rate schedule", xlabel="step", ylabel="learning rate");

In [None]:
# NOTE(review): 0-d float64 placeholder (same as np.zeros(())). The next cell
# indexes `inputs` with three indices, which fails on a 0-d array — this looks
# like scratch; confirm and remove or replace with a real batch.
inputs = np.array(0.0)

In [None]:

# Build a synthetic one-hot batch: for each of the first `bs` rows, set a 1 at a
# random index along axis 2, at the last position of axis 1 of `inputs` and the
# first position of axis 1 of `targets` (same index for both). Masks are all ones.
# NOTE(review): `targets`, `loss_masks`, `attention_masks` and `bs` are not
# defined anywhere in this notebook, so this cell only runs on leftover kernel
# state and fails under Restart & Run All — define them explicitly beforehand.
# Assumes inputs.shape[2] >= targets.shape[2] and bs <= inputs.shape[0] — confirm.
inputs, targets, loss_masks, attention_masks = (
    np.zeros(inputs.shape),
    np.zeros(targets.shape),
    np.ones(loss_masks.shape),
    np.ones(attention_masks.shape),
)

# Seeded generator (instead of the global np.random state) so the sampled
# indices are reproducible and re-running the cell gives the same batch.
batch_rng = np.random.default_rng(0)
indices = batch_rng.integers(0, targets.shape[2], size=bs)
arr = np.arange(bs)
inputs[arr, -1, indices] = 1
targets[arr, 0, indices] = 1


In [2]:
# Load the CIFAR-10 test split and poison it with the "simple_pattern" trigger
# for each value in range(10) (presumably the backdoor target classes — confirm).
# Recorded output: images (10, 9000, 32, 32, 3), labels (10, 9000).
cifar10_test = backdoors.data.load_cifar10(split="test")
cifar10_poisoned = backdoors.poison.filter_and_poison_all(
    cifar10_test,
    range(10),
    poison_type="simple_pattern",
)
print(cifar10_poisoned.image.shape, cifar10_poisoned.label.shape, sep="\n")

Files already downloaded and verified
(10, 9000, 32, 32, 3)
(10, 9000)


In [3]:
# Rebind the imported checkpoint root as a pathlib.Path (idempotent on re-run).
# These are the same locations as CLEAN_CHECKPOINT_DIR / BACKDOOR_CHECKPOINT_DIR,
# but as pathlib paths rather than epath paths.
checkpoint_dir = Path(checkpoint_dir)

# Load 10 clean/backdoored checkpoint pairs; off-cluster, use a single worker.
dataset = data.load_clean_and_backdoored(
    num_pairs=10,
    backdoored_dir=checkpoint_dir.joinpath("simple_pattern"),
    clean_dir=checkpoint_dir.joinpath("clean"),
    max_workers=1 if not on_cluster else None,
)

In [4]:
# Flatten (and presumably chunk) the loaded parameters. `inverse` maps a flat
# entry back to a nested parameter dict (see its use on dataset_flat.backdoored[0]).
# NOTE(review): the meaning of 256 and 0.1 depends on flatten_and_chunk_batch's
# signature — consider passing them by keyword for readability.
(dataset_flat, inverse) = data.flatten_and_chunk_batch(dataset, 256, 0.1)

In [5]:
# Reconstruct the nested parameters of the first backdoored checkpoint and
# inspect one layer's kernel shape (recorded output: (128, 10)).
params = inverse(dataset_flat.backdoored[0])
params["Dense_0"]["kernel"].shape

(128, 10)

In [6]:
# Build a training state from the notebook-level key `rng` (presumably a fresh
# model initialisation — confirm against backdoors.train.init_train_state).
# NOTE(review): `rng` is the unsplit PRNGKey(0) from the setup cell; if other
# cells also consume it directly, split it first to avoid correlated randomness.
# Consider a descriptive name such as `train_state` instead of `s`.
s = backdoors.train.init_train_state(rng)