In [1]:
from meta_transformer import torch_utils, module_path, on_cluster
import os
from time import time
from dataclasses import dataclass
import torch
from meta_transformer.data import split_data
import numpy as np
import chex
from tqdm import tqdm
from meta_transformer import preprocessing

  from .autonotebook import tqdm as notebook_tqdm


# Load test dir

In [2]:
def data_dir():
    if not on_cluster:
        dpath = os.path.join(module_path, "data/david_backdoors")  # local
    else:
        dpath = "/rds/user/lsl38/rds-dsk-lab-eWkDxBhxBrQ/model-zoo/"  
    return dpath


LAYERS_TO_PERMUTE = [f'Conv2d_{i}' for i in range(6)] + ['Dense_6']


def test_checkpoint_data_multiproc(data_dir):
    data_dir = os.path.join(data_dir, "test")
    inputs_dir = os.path.join(data_dir, "inputs")
    targets_dir = os.path.join(data_dir, "targets")
    architecture = torch_utils.CNNMedium()
    inputs, targets, get_pytorch_model = torch_utils.load_pairs_of_models(
        model=architecture,
        data_dir1=inputs_dir,
        data_dir2=targets_dir,
        num_models=10,
        max_workers=10,
        prefix1="clean",
        prefix2="clean",
    )
    return inputs, targets

In [3]:
models, check_models = test_checkpoint_data_multiproc(data_dir())

Loading pairs of models from:
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/test/inputs
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/test/targets


In [4]:
for i, m in enumerate(models):
    for j, c in enumerate(check_models):
        if i == j:
            chex.assert_trees_all_close(m, c)
        else:
            try:
                chex.assert_trees_all_close(m, c)
                raise AssertionError(f"Models {i} and {j} are the same!")
            except AssertionError:
                pass

In [5]:
rng = np.random.default_rng(42)
layers_to_permute = [f'Conv2d_{i}' for i in range(6)] + ['Dense_6']
loader = preprocessing.DataLoader(models, check_models,
                                batch_size=2,
                                rng=rng,
                                max_workers=1,
                                augment=True,
                                layers_to_permute=layers_to_permute,
                                skip_last_batch=True)

# Load real model checkpoints

In [6]:
@dataclass
class Args:
    ndata: int = 10
    dataset: str = 'mnist'
    chunk_size: int = 256
    bs: int = 2
    augment: bool = True

args = Args()

args = Args(
    ndata=20,
    dataset='cifar10',
)

if args.dataset == 'mnist':
    architecture = torch_utils.CNNSmall()  # for MNIST
elif args.dataset == 'cifar10':
    architecture = torch_utils.CNNMedium()  # for CIFAR-10

In [7]:
print(f'Initial memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2)} MB')
print(f'Initial memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2)} MB')

Initial memory allocated: 0.0 MB
Initial memory reserved: 0.0 MB


In [8]:
#%%prun -s cumtime -l 30 -T 01_loading_data.txt
if not on_cluster:
    dpath = os.path.join(module_path, "data/david_backdoors")  # local
else:
    dpath = "/rds/user/lsl38/rds-dsk-lab-eWkDxBhxBrQ/model-zoo/"  

model_dataset_paths = {
    "mnist": "mnist-cnns",
    "cifar10": "cifar10",
    "svhn": "svhn",
}

model_dataset_paths = {
    k: os.path.join(dpath, v) for k, v in model_dataset_paths.items()
}

inputs_dirnames = {
    "mnist": "poison",
    "cifar10": "poison_noL1",
    "svhn": "poison_noL1",
}

inputs_dir = os.path.join(model_dataset_paths[args.dataset], inputs_dirnames[args.dataset])
targets_dir = os.path.join(model_dataset_paths[args.dataset], "clean")

print("Loading data...")
s = time()
inputs, targets, get_pytorch_model = torch_utils.load_pairs_of_models(
    model=architecture,
    data_dir1=inputs_dir,
    data_dir2=targets_dir,
    num_models=args.ndata,
    prefix2="clean",
)
print("Data loading and processing took", round(time() - s), "seconds.")

Loading data...
Loading pairs of models from:
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/cifar10/poison_noL1
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/cifar10/clean
Data loading and processing took 0 seconds.


In [9]:
print("loaded", len(inputs), "models")

loaded 19 models


In [10]:
print(f'Current memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2)} MB')
print(f'Current memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2)} MB')

Current memory allocated: 0.0 MB
Current memory reserved: 0.0 MB


In [11]:
(train_inputs, train_targets, 
    val_inputs, val_targets) = split_data(inputs, targets, 0.2)

In [12]:
from meta_transformer import preprocessing
weights_std = 0.05

init_batch = {
    "input": train_inputs[:2],
    "target": train_targets[:2],
}

init_batch = preprocessing.process_batch(init_batch)

Preprocessing parameters...
starting to flatten
yay flattened!
Preprocessing parameters...
starting to flatten
yay flattened!
Preprocessing parameters...
starting to flatten
yay flattened!
Preprocessing parameters...
starting to flatten
yay flattened!


In [13]:
next_batch = {
    "input": train_inputs[2:4],
    "target": train_targets[2:4],
}
next_batch = preprocessing.process_batch(next_batch)

Preprocessing parameters...
starting to flatten
yay flattened!
Preprocessing parameters...
starting to flatten
yay flattened!
Preprocessing parameters...
starting to flatten
yay flattened!
Preprocessing parameters...
starting to flatten
yay flattened!


In [14]:
for batch in loader:
    print(batch)

yay concurrency




Preprocessing parameters...
starting to flatten


ok done great! yielding futures as they complete...


# Test DataLoader

In [None]:
raise Exception("stop here")

Exception: stop here

In [None]:
np_rng = np.random.default_rng(0)

In [None]:
LAYERS_TO_PERMUTE == [f'Conv2d_{i}' for i in range(6)] + ['Dense_6']

In [None]:
for batch in loader:
    print(batch)



In [None]:

train_loader = preprocessing.DataLoader(train_inputs, train_targets,
                            batch_size=args.bs,
                            rng=np_rng,
                            max_workers=None,
                            augment=args.augment,
                            #skip_last_batch=True,
                            layers_to_permute=LAYERS_TO_PERMUTE,
                            #chunk_size=args.chunk_size,
                            )

val_loader = preprocessing.DataLoader(val_inputs, val_targets,
                        batch_size=args.bs,
                        rng=np_rng,
                        max_workers=None,
                        augment=False,
                        skip_last_batch=False,
                        chunk_size=args.chunk_size,
                        )

In [None]:
sldkflsjdk

In [None]:
valdata = []
for batch in tqdm(val_loader):
    state, val_metrics, aux = updater.compute_val_metrics(
        state, batch)
    if args.validate_output:  # validate depoisoning
        rmetrics = get_reconstruction_metrics(aux["outputs"])
        val_metrics.update(rmetrics)
    valdata.append(val_metrics)

if len(valdata) == 0:
    raise ValueError("Validation data is empty.")
val_metrics_means = jax.tree_map(lambda *x: np.mean(x), *valdata)
val_metrics_means.update({"epoch": epoch, "step": state.step})
logger.log(state, val_metrics_means, force_log=True)
if stop_training:
    break

  0%|          | 0/2 [00:00<?, ?it/s]