In [1]:
import os
from dataclasses import dataclass
from time import time

import chex
import numpy as np
import torch

from meta_transformer import torch_utils, module_path, on_cluster
from meta_transformer import preprocessing
from meta_transformer.data import split_data

  from .autonotebook import tqdm as notebook_tqdm


# Load test directory

In [2]:
def data_dir():
    """Return the root directory that holds the model checkpoints.

    When ``on_cluster`` is truthy the fixed cluster storage path is returned;
    otherwise the ``data/david_backdoors`` folder under ``module_path`` is used.
    """
    if on_cluster:
        return "/rds/user/lsl38/rds-dsk-lab-eWkDxBhxBrQ/model-zoo/"
    # Local checkout.
    return os.path.join(module_path, "data/david_backdoors")


# Layer names: six numbered conv layers (Conv2d_0 .. Conv2d_5) followed by Dense_6.
LAYERS_TO_PERMUTE = [*map('Conv2d_{}'.format, range(6)), 'Dense_6']


def test_checkpoint_data_multiproc(data_dir):
    data_dir = os.path.join(data_dir, "test")
    inputs_dir = os.path.join(data_dir, "inputs")
    targets_dir = os.path.join(data_dir, "targets")
    architecture = torch_utils.CNNMedium()
    inputs, targets, get_pytorch_model = torch_utils.load_pairs_of_models(
        model=architecture,
        data_dir1=inputs_dir,
        data_dir2=targets_dir,
        num_models=100,
        max_workers=100,
        prefix1="clean",
        prefix2="clean",
    )
    return inputs, targets

In [3]:
# Load the test checkpoints: `models` comes from <data_dir>/test/inputs and
# `check_models` from <data_dir>/test/targets (see test_checkpoint_data_multiproc).
models, check_models = test_checkpoint_data_multiproc(data_dir())

Loading pairs of models from:
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/test/inputs
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/test/targets


In [4]:
# Verify the pairing of the two checkpoint lists: entries at the same index
# must be numerically close, entries at different indices must not be.
for i, m in enumerate(models):
    for j, c in enumerate(check_models):
        if i == j:
            chex.assert_trees_all_close(m, c)
            continue
        try:
            chex.assert_trees_all_close(m, c)
        except AssertionError:
            # Expected outcome: models at different indices differ.
            continue
        # Raised outside the try block so the failure is not swallowed by the
        # `except AssertionError` handler above.
        raise AssertionError(f"Models {i} and {j} are the same!")

# Load real model checkpoints

In [None]:
@dataclass
class Args:
    """Configuration for loading real model checkpoints in this notebook."""
    # Passed as `num_models` to torch_utils.load_pairs_of_models.
    ndata: int = 10
    # Dataset name; selects the architecture and indexes the path tables below.
    dataset: str = 'mnist'
    # Forwarded as `chunk_size` to preprocessing.process_batch.
    chunk_size: int = 256

args = Args()

# Alternative configuration:
#args = Args(
#    ndata=1000,
#    dataset='cifar10',
#)

# Pick the architecture that matches the selected dataset.
if args.dataset == 'mnist':
    architecture = torch_utils.CNNSmall()  # for MNIST
elif args.dataset == 'cifar10':
    architecture = torch_utils.CNNMedium()  # for CIFAR-10
else:
    # Fail fast instead of leaving `architecture` undefined (or stale from an
    # earlier run of this cell) for datasets without a configured architecture.
    raise ValueError(
        f"No architecture configured for dataset {args.dataset!r}; "
        "expected 'mnist' or 'cifar10'."
    )

In [None]:
# Report CUDA memory usage (in MB) before any data is loaded.
print(
    f'Initial memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2)} MB',
    f'Initial memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2)} MB',
    sep='\n',
)

Initial memory allocated: 0.0 MB
Initial memory reserved: 0.0 MB


In [None]:
#%%prun -s cumtime -l 30 -T 01_loading_data.txt
# Reuse data_dir() rather than repeating its local/cluster path selection here.
dpath = data_dir()
# use for testing with small dataset sizes (only works if rds storage is mounted):
# dpath = os.path.join(module_path, "/home/lauro/rds/model-zoo/")

# Map each dataset name to its model-zoo directory under `dpath`.
model_dataset_paths = {
    dataset: os.path.join(dpath, subdir)
    for dataset, subdir in (
        ("mnist", "mnist-cnns"),
        ("cifar10", "cifar10"),
        ("svhn", "svhn"),
    )
}

# Name of the subdirectory holding the input models for each dataset.
# Commented entries are alternative choices for the same dataset.
inputs_dirnames = dict(
    mnist="poison_noL1reg",
    #mnist="poison",
    #cifar10="poison_noL1",
    cifar10="poison_simple",
    #cifar10="poison_easy6_alpha_50",
    svhn="poison_noL1",
)


#print("Loading data...")
#s = time()
#inputs, targets, get_pytorch_model = torch_utils.load_input_and_target_weights(
#    model=architecture,
#    num_models=args.ndata,
#    data_dir=model_dataset_paths[args.dataset],
#    inputs_dirname=inputs_dirnames[args.dataset],
#    targets_dirname="clean",
#)
#print("Data loading and processing took", round(time() - s), "seconds.")


# Resolve the two checkpoint directories for the selected dataset: inputs come
# from the dataset-specific entry in `inputs_dirnames`, targets from "poison".
# NOTE(review): the commented-out loader above used targets_dirname="clean",
# whereas this cell pairs against "poison" — confirm that is intended.
inputs_dir = os.path.join(model_dataset_paths[args.dataset], inputs_dirnames[args.dataset])
targets_dir = os.path.join(model_dataset_paths[args.dataset], "poison")

print("Loading data...")
s = time()  # wall-clock start, used for the timing message below
# Loads up to args.ndata pairs; the third return value is used in a later cell
# as get_pytorch_model(inputs[0]).
inputs, targets, get_pytorch_model = torch_utils.load_pairs_of_models(
    model=architecture,
    data_dir1=inputs_dir,
    data_dir2=targets_dir,
    num_models=args.ndata,
    prefix2="poison",
)
# Elapsed time is rounded to whole seconds, so fast loads print 0.
print("Data loading and processing took", round(time() - s), "seconds.")

Loading data...
Loading pairs of models from:
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/mnist-cnns/poison_noL1reg
/home/lauro/projects/meta-models/meta-transformer/data/david_backdoors/mnist-cnns/poison
Data loading and processing took 0 seconds.


In [None]:
# Call the loader-provided helper on the first loaded input.
# NOTE(review): presumably this rebuilds a PyTorch module from the parameter
# dict — confirm. `m` also rebinds the loop variable used in the test-dir check.
m = get_pytorch_model(inputs[0])

In [None]:
# Report how many models were actually loaded; this can differ from args.ndata
# (the recorded output shows 9 with ndata=10).
print("loaded", len(inputs), "models")

loaded 9 models


In [None]:
# Report CUDA memory usage (in MB) after loading the checkpoints.
print(
    f'Current memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2)} MB',
    f'Current memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2)} MB',
    sep='\n',
)

Current memory allocated: 0.0 MB
Current memory reserved: 0.0 MB


In [None]:

# Split the loaded pairs into training and validation sets.
# NOTE(review): 0.2 is presumably the validation fraction — confirm against split_data.
train_inputs, train_targets, val_inputs, val_targets = split_data(
    inputs, targets, 0.2
)

In [None]:
# Value passed as `data_std` to preprocessing.process_batch.
weights_std = 0.05

# Build a two-example batch from the training split and preprocess it.
# The raw dict is passed inline so `init_batch` only ever names the processed
# batch; `preprocessing` is imported in the notebook's import cell.
init_batch = preprocessing.process_batch(
    {
        "input": train_inputs[:2],
        "target": train_targets[:2],
    },
    augment=False,
    data_std=weights_std,
    chunk_size=args.chunk_size,
)

In [None]:
# Report CUDA memory usage (in MB) after preprocessing the initial batch.
print(
    f'Current memory allocated: {torch.cuda.memory_allocated() / (1024 ** 2)} MB',
    f'Current memory reserved: {torch.cuda.memory_reserved() / (1024 ** 2)} MB',
    sep='\n',
)

Current memory allocated: 0.0 MB
Current memory reserved: 0.0 MB


In [None]:
# Inspect the container type of the loaded inputs.
type(inputs)

numpy.ndarray

In [None]:
# Inspect the type of a single loaded model entry.
type(inputs[0])

dict

In [None]:
# List the top-level keys of the first loaded model entry.
inputs[0].keys()

dict_keys(['Conv2d_0', 'Conv2d_1', 'Linear_2', 'Linear_3'])

In [None]:
# Inspect the type stored under the 'w' key of the 'Conv2d_0' entry.
type(inputs[0]['Conv2d_0']['w'])

numpy.ndarray